Skip to content

[feat] Faster topk algorithm#3009

Merged
kahyunnam merged 11 commits intoflashinfer-ai:mainfrom
Aalanli:topk_general
Apr 20, 2026
Merged

[feat] Faster topk algorithm#3009
kahyunnam merged 11 commits intoflashinfer-ai:mainfrom
Aalanli:topk_general

Conversation

@Aalanli
Copy link
Copy Markdown
Contributor

@Aalanli Aalanli commented Apr 7, 2026

📌 Description

This PR implements a faster topk algorithm that uses sm90+ CTA clusters feature. This is a non-deterministic algorithm, but does not drop indices and instead overflows to global memory. Benchmark results show that it's faster than both the multi-cta topk algorithm and the filtering algorithm overall. The cases it's slower is when the overflow happens too much.
Note: Speedup is speedup of flashinfer vs torch, while Speedup Clusters vs. Default is speed up of this kernel over flashinfer.

====================================================================================================
top_k: Basic radix-based top-k selection (dtype=FP32, deterministic=False, pattern=random)
NOTE: default top-k sweep includes two extra large-batch/long-vocab stress cases beyond the original grid
====================================================================================================
 batch    seq_len      k |   FlashInfer   torch.topk    Speedup     Clusters  Speedup Clusters vs. Default
-------------------------------------------------------------------------------------------------------------------
     1        256    256 |         4.42us      20.64us      4.67x       1.50us                         2.94x
     1        512    256 |         5.79us      22.11us      3.82x       6.18us                         0.94x
     1        512    512 |         5.38us      21.28us      3.96x       1.54us                         3.50x
     1       1024    256 |        10.05us      27.68us      2.75x       6.40us                         1.57x
     1       1024    512 |         6.56us      24.51us      3.74x       6.40us                         1.02x
     1       1024   1024 |         5.12us      25.38us      4.96x       1.60us                         3.20x
     1       2048    256 |        10.08us      31.17us      3.09x       7.10us                         1.42x
     1       2048    512 |        11.01us      31.94us      2.90x       7.20us                         1.53x
     1       2048   1024 |         9.34us      32.03us      3.43x       7.07us                         1.32x
     1       2048   2048 |         5.79us      27.26us      4.71x       1.76us                         3.29x
     1       4096    256 |         9.12us      36.06us      3.95x       7.30us                         1.25x
     1       4096    512 |        11.42us      41.70us      3.65x       7.39us                         1.55x
     1       4096   1024 |        11.52us      41.15us      3.57x       7.58us                         1.52x
     1       4096   2048 |        10.22us      32.45us      3.17x       7.30us                         1.40x
     1       4096   4096 |         7.17us      35.78us      4.99x       2.02us                         3.56x
     1      16384    256 |        12.83us      78.53us      6.12x      11.49us                         1.12x
     1      16384    512 |        14.98us      89.09us      5.95x      12.24us                         1.22x
     1      16384   1024 |        15.95us      80.99us      5.08x      12.26us                         1.30x
     1      16384   2048 |        17.50us      93.33us      5.33x      12.51us                         1.40x
     1      16384   4096 |        21.28us      84.70us      3.98x      12.80us                         1.66x
     1      65536    256 |        33.44us      94.88us      2.84x      12.48us                         2.68x
     1      65536    512 |        34.24us      92.26us      2.69x      12.35us                         2.77x
     1      65536   1024 |        34.75us      91.52us      2.63x      12.42us                         2.80x
     1      65536   2048 |        36.58us      91.87us      2.51x      13.73us                         2.66x
     1      65536   4096 |        37.98us      92.03us      2.42x      13.92us                         2.73x
     1     131072    256 |        39.10us      99.81us      2.55x      14.69us                         2.66x
     1     131072    512 |        40.26us      98.13us      2.44x      14.72us                         2.73x
     1     131072   1024 |        39.81us      99.92us      2.51x      14.88us                         2.68x
     1     131072   2048 |        41.94us     101.50us      2.42x      14.98us                         2.80x
     1     131072   4096 |        43.62us     103.04us      2.36x      18.40us                         2.37x
     1     262144    256 |        43.46us     116.40us      2.68x      19.26us                         2.26x
     1     262144    512 |        44.46us     115.49us      2.60x      19.36us                         2.30x
     1     262144   1024 |        44.35us     114.14us      2.57x      19.34us                         2.29x
     1     262144   2048 |        45.89us     115.36us      2.51x      19.52us                         2.35x
     1     262144   4096 |        46.78us     116.83us      2.50x      19.71us                         2.37x
     1     524288    256 |        45.60us     152.27us      3.34x      28.00us                         1.63x
     1     524288    512 |        46.91us     146.75us      3.13x      28.19us                         1.66x
     1     524288   1024 |        46.34us     151.39us      3.27x      28.35us                         1.63x
     1     524288   2048 |        47.14us     149.55us      3.17x      28.58us                         1.65x
     1     524288   4096 |        47.76us     151.58us      3.17x      28.54us                         1.67x
    16        256    256 |         5.95us      20.99us      3.53x       1.50us                         3.96x
    16        512    256 |         9.98us      23.07us      2.31x       6.43us                         1.55x
    16        512    512 |         5.60us      22.82us      4.07x       1.50us                         3.72x
    16       1024    256 |        11.30us      28.86us      2.56x       6.59us                         1.71x
    16       1024    512 |         9.12us      26.37us      2.89x       6.62us                         1.38x
    16       1024   1024 |         6.21us      27.30us      4.40x       1.60us                         3.88x
    16       2048    256 |        11.65us      31.87us      2.74x       7.23us                         1.61x
    16       2048    512 |        10.62us      32.64us      3.07x       7.36us                         1.44x
    16       2048   1024 |        10.08us      30.27us      3.00x       7.30us                         1.38x
    16       2048   2048 |         6.75us      30.46us      4.51x       1.76us                         3.84x
    16       4096    256 |        11.55us      43.04us      3.73x       7.39us                         1.56x
    16       4096    512 |        12.26us      44.70us      3.65x       7.49us                         1.64x
    16       4096   1024 |        12.06us      42.21us      3.50x       7.63us                         1.58x
    16       4096   2048 |        12.10us      38.69us      3.20x       7.65us                         1.58x
    16       4096   4096 |         7.84us      41.50us      5.29x       2.02us                         3.89x
    16      16384    256 |        14.98us      92.72us      6.19x      14.18us                         1.06x
    16      16384    512 |        15.17us      96.22us      6.34x      15.38us                         0.99x
    16      16384   1024 |        17.73us      98.88us      5.58x      15.33us                         1.16x
    16      16384   2048 |        18.18us      98.50us      5.42x      15.81us                         1.15x
    16      16384   4096 |        21.63us     107.09us      4.95x      16.29us                         1.33x
    16      65536    256 |        27.73us     104.93us      3.78x      16.13us                         1.72x
    16      65536    512 |        32.16us     103.86us      3.23x      16.35us                         1.97x
    16      65536   1024 |        32.58us     103.84us      3.19x      16.16us                         2.02x
    16      65536   2048 |        35.90us     107.92us      3.01x      19.20us                         1.87x
    16      65536   4096 |        42.08us     110.78us      2.63x      19.62us                         2.15x
    16     131072    256 |        43.92us     115.10us      2.62x      20.96us                         2.10x
    16     131072    512 |        43.46us     112.86us      2.60x      21.23us                         2.05x
    16     131072   1024 |        53.47us     115.94us      2.17x      21.25us                         2.52x
    16     131072   2048 |        54.43us     116.69us      2.14x      21.25us                         2.56x
    16     131072   4096 |        48.22us     122.78us      2.55x      26.88us                         1.79x
    16     262144    256 |        49.12us     136.64us      2.78x      27.54us                         1.78x
    16     262144    512 |        49.33us     136.82us      2.77x      28.03us                         1.76x
    16     262144   1024 |        49.46us     138.70us      2.80x      28.03us                         1.76x
    16     262144   2048 |        50.59us     139.33us      2.75x      28.29us                         1.79x
    16     262144   4096 |        51.55us     142.85us      2.77x      28.77us                         1.79x
    16     524288    256 |        89.95us     181.86us      2.02x      42.38us                         2.12x
    16     524288    512 |        90.51us     178.05us      1.97x      43.17us                         2.10x
    16     524288   1024 |        90.75us     179.14us      1.97x      42.72us                         2.12x
    16     524288   2048 |        91.73us     182.66us      1.99x      43.39us                         2.11x
    16     524288   4096 |        93.60us     185.39us      1.98x      43.15us                         2.17x
    64        256    256 |         6.37us      22.05us      3.46x       1.50us                         4.24x
    64        512    256 |        10.53us      23.68us      2.25x       6.53us                         1.61x
    64        512    512 |         5.50us      23.55us      4.28x       1.54us                         3.58x
    64       1024    256 |        12.03us      31.09us      2.58x       6.70us                         1.79x
    64       1024    512 |        10.62us      27.30us      2.57x       6.77us                         1.57x
    64       1024   1024 |         6.66us      26.91us      4.04x       1.63us                         4.08x
    64       2048    256 |        11.34us      35.07us      3.09x       7.39us                         1.53x
    64       2048    512 |        10.88us      34.18us      3.14x       7.42us                         1.47x
    64       2048   1024 |        11.20us      32.42us      2.89x       7.52us                         1.49x
    64       2048   2048 |         7.39us      32.06us      4.34x       1.79us                         4.12x
    64       4096    256 |        12.67us      43.23us      3.41x       7.68us                         1.65x
    64       4096    512 |        11.71us      43.14us      3.68x       7.71us                         1.52x
    64       4096   1024 |        12.13us      46.66us      3.85x       7.87us                         1.54x
    64       4096   2048 |        11.97us      40.22us      3.36x       7.78us                         1.54x
    64       4096   4096 |         8.64us      40.93us      4.74x       2.08us                         4.15x
    64      16384    256 |        15.42us      98.59us      6.39x      14.11us                         1.09x
    64      16384    512 |        15.39us     101.60us      6.60x      15.42us                         1.00x
    64      16384   1024 |        17.98us      99.87us      5.55x      15.65us                         1.15x
    64      16384   2048 |        19.14us     102.14us      5.34x      15.97us                         1.20x
    64      16384   4096 |        22.37us     104.22us      4.66x      16.46us                         1.36x
    64      65536    256 |        27.97us     135.71us      4.85x      20.93us                         1.34x
    64      65536    512 |        32.85us     136.61us      4.16x      21.12us                         1.56x
    64      65536   1024 |        34.08us     135.14us      3.97x      21.22us                         1.61x
    64      65536   2048 |        36.93us     139.07us      3.77x      27.04us                         1.37x
    64      65536   4096 |        43.30us     141.86us      3.28x      28.35us                         1.53x
    64     131072    256 |        44.80us     178.22us      3.98x      29.41us                         1.52x
    64     131072    512 |        44.53us     174.94us      3.93x      29.57us                         1.51x
    64     131072   1024 |        54.59us     174.56us      3.20x      29.70us                         1.84x
    64     131072   2048 |        55.46us     177.73us      3.20x      30.24us                         1.83x
    64     131072   4096 |        90.98us     181.47us      1.99x      44.93us                         2.02x
    64     262144    256 |        75.17us     240.90us      3.20x      47.73us                         1.57x
    64     262144    512 |        80.53us     239.86us      2.98x      47.92us                         1.68x
    64     262144   1024 |        81.97us     241.82us      2.95x      48.13us                         1.70x
    64     262144   2048 |        98.64us     241.71us      2.45x      48.58us                         2.03x
    64     262144   4096 |       143.36us     244.83us      1.71x      49.95us                         2.87x
    64     524288    256 |       162.29us     488.82us      3.01x      84.50us                         1.92x
    64     524288    512 |       161.41us     488.70us      3.03x      84.70us                         1.91x
    64     524288   1024 |       175.39us     491.07us      2.80x      85.01us                         2.06x
    64     524288   2048 |       175.98us     494.22us      2.81x      85.89us                         2.05x
    64     524288   4096 |       227.90us     494.98us      2.17x      88.62us                         2.57x
   256        256    256 |         7.04us      23.15us      3.29x       1.82us                         3.86x
   256        512    256 |        15.12us      27.04us      1.79x       7.87us                         1.92x
   256        512    512 |         7.15us      26.37us      3.69x       1.89us                         3.79x
   256       1024    256 |        17.66us      40.45us      2.29x       8.26us                         2.14x
   256       1024    512 |        15.65us      37.55us      2.40x       8.32us                         1.88x
   256       1024   1024 |         7.58us      38.83us      5.12x       2.14us                         3.54x
   256       2048    256 |        18.66us      49.60us      2.66x       9.34us                         2.00x
   256       2048    512 |        18.66us      51.17us      2.74x       9.47us                         1.97x
   256       2048   1024 |        17.09us      46.14us      2.70x       9.66us                         1.77x
   256       2048   2048 |         8.99us      47.52us      5.28x       2.37us                         3.80x
   256       4096    256 |        19.23us      99.17us      5.16x      10.72us                         1.79x
   256       4096    512 |        19.97us      99.41us      4.98x      10.91us                         1.83x
   256       4096   1024 |        20.58us     102.10us      4.96x      11.28us                         1.82x
   256       4096   2048 |        18.34us     103.58us      5.65x      11.07us                         1.66x
   256       4096   4096 |        12.54us     114.22us      9.11x       2.82us                         4.45x
   256      16384    256 |        28.42us     131.98us      4.64x      16.19us                         1.75x
   256      16384    512 |        29.50us     135.30us      4.59x      22.11us                         1.33x
   256      16384   1024 |        31.34us     136.93us      4.37x      23.23us                         1.35x
   256      16384   2048 |        34.24us     138.99us      4.06x      24.26us                         1.41x
   256      16384   4096 |        40.64us     149.52us      3.68x      25.12us                         1.62x
   256      65536    256 |        53.92us     240.27us      4.46x      42.66us                         1.26x
   256      65536    512 |        64.16us     243.09us      3.79x      43.22us                         1.48x
   256      65536   1024 |        64.93us     243.83us      3.76x      43.97us                         1.48x
   256      65536   2048 |        71.01us     250.61us      3.53x      76.10us                         0.93x
   256      65536   4096 |       150.53us     260.10us      1.73x      83.87us                         1.79x
   256     131072    256 |        92.15us     487.68us      5.29x      81.01us                         1.14x
   256     131072    512 |        92.16us     488.58us      5.30x      81.44us                         1.13x
   256     131072   1024 |       111.23us     493.68us      4.44x      82.45us                         1.35x
   256     131072   2048 |       112.56us     495.41us      4.40x      84.58us                         1.33x
   256     131072   4096 |       259.01us     509.48us      1.97x     161.81us                         1.60x
   256     262144    256 |       178.59us     881.03us      4.93x     155.81us                         1.15x
   256     262144    512 |       192.08us     879.04us      4.58x     156.94us                         1.22x
   256     262144   1024 |       192.74us     879.60us      4.56x     158.18us                         1.22x
   256     262144   2048 |       230.40us     885.49us      3.84x     160.82us                         1.43x
   256     262144   4096 |       411.46us     898.68us      2.18x     170.59us                         2.41x
   256     524288    256 |       327.39us    1542.01us      4.71x     299.43us                         1.09x
   256     524288    512 |       327.95us    1542.65us      4.70x     300.77us                         1.09x
   256     524288   1024 |       355.35us    1544.52us      4.35x     303.06us                         1.17x
   256     524288   2048 |       357.19us    1549.56us      4.34x     305.75us                         1.17x
   256     524288   4096 |       824.29us    1568.89us      1.90x     316.91us                         2.60x
  2048     131072   1024 |       703.20us    3012.43us      4.28x     470.11us                         1.50x
  4096     200000   1024 |      1994.76us    8891.66us      4.46x    1383.16us                         1.44x

====================================================================================================
dsa_topk: DeepSeek DSA-like indexer top-k workload (dtype=FP32, deterministic=False, dsa_pattern=dsa_relu, k=2048)
====================================================================================================
                    case     rows    seq_len      k |   FlashInfer   torch.topk    Speedup     Clusters  Speedup Clusters vs. Default
---------------------------------------------------------------------------------------------------------------------------------
      decode_b1_q1_l128k        1     131072   2048 |      42.59us      98.62us      2.32x      14.98us                         2.84x
       decode_b8_q1_l64k        8      65536   2048 |      37.01us      97.70us      2.64x      14.59us                         2.54x
     decode_b32_q1_l128k       32     131072   2048 |      56.13us     135.20us      2.41x      23.01us                         2.44x
   prefill_b1_q128_l128k      128     131072   2048 |      62.02us     241.52us      3.89x      49.18us                         1.26x

====================================================================================================
top_k_page_table_transform: Fused top-k + page table gather (dtype=FP32, deterministic=False, pattern=random)
====================================================================================================
 batch    seq_len      k |   FlashInfer     Clusters  Speedup Clusters vs. Default
-------------------------------------------------------------------------------------------------------------
     1        256    256 |       2.85us       2.43us                         1.17x
     1        512    256 |       4.72us       7.04us                         0.67x
     1        512    512 |       2.94us       2.56us                         1.15x
     1       1024    256 |       7.04us       7.01us                         1.00x
     1       1024    512 |       4.45us       6.77us                         0.66x
     1       1024   1024 |       2.56us       2.24us                         1.14x
     1       2048    256 |       8.64us       7.92us                         1.09x
     1       2048    512 |       7.38us       7.87us                         0.94x
     1       2048   1024 |       7.12us       7.71us                         0.92x
     1       2048   2048 |       3.26us       2.91us                         1.12x
     1       4096    256 |       8.93us       8.26us                         1.08x
     1       4096    512 |       9.06us       8.35us                         1.08x
     1       4096   1024 |       7.95us       8.32us                         0.96x
     1       4096   2048 |       7.65us       8.00us                         0.96x
     1       4096   4096 |       4.16us       2.62us                         1.59x
     1      16384    256 |      12.26us      12.70us                         0.96x
     1      16384    512 |      13.10us      13.26us                         0.99x
     1      16384   1024 |      14.46us      13.50us                         1.07x
     1      16384   2048 |      15.46us      13.54us                         1.14x
     1      16384   4096 |      20.35us      13.73us                         1.48x
     1      65536    256 |      30.30us      13.07us                         2.32x
     1      65536    512 |      31.14us      13.12us                         2.37x
     1      65536   1024 |      31.58us      13.18us                         2.40x
     1      65536   2048 |      33.50us      14.62us                         2.29x
     1      65536   4096 |      35.33us      14.82us                         2.38x
     1     131072    256 |      35.42us      15.07us                         2.35x
     1     131072    512 |      35.87us      15.17us                         2.36x
     1     131072   1024 |      36.45us      15.33us                         2.38x
     1     131072   2048 |      38.05us      15.30us                         2.49x
     1     131072   4096 |      41.22us      19.17us                         2.15x
     1     262144    256 |      40.08us      19.78us                         2.03x
     1     262144    512 |      40.58us      20.32us                         2.00x
     1     262144   1024 |      41.15us      19.81us                         2.08x
     1     262144   2048 |      42.82us      20.02us                         2.14x
     1     262144   4096 |      45.06us      20.22us                         2.23x
     1     524288    256 |      42.75us      28.21us                         1.52x
     1     524288    512 |      42.82us      28.29us                         1.51x
     1     524288   1024 |      43.10us      28.26us                         1.53x
     1     524288   2048 |      43.71us      28.38us                         1.54x
     1     524288   4096 |      45.12us      29.31us                         1.54x
    16        256    256 |       2.75us       2.37us                         1.16x
    16        512    256 |       7.71us       7.17us                         1.08x
    16        512    512 |       2.85us       2.46us                         1.16x
    16       1024    256 |       8.19us       7.30us                         1.12x
    16       1024    512 |       6.94us       7.39us                         0.94x
    16       1024   1024 |       2.82us       2.56us                         1.10x
    16       2048    256 |       8.70us       8.08us                         1.08x
    16       2048    512 |       8.67us       8.19us                         1.06x
    16       2048   1024 |       8.19us       8.06us                         1.02x
    16       2048   2048 |       3.52us       3.20us                         1.10x
    16       4096    256 |       9.05us       8.45us                         1.07x
    16       4096    512 |      10.14us       8.50us                         1.19x
    16       4096   1024 |       9.31us       8.61us                         1.08x
    16       4096   2048 |       8.74us       8.35us                         1.05x
    16       4096   4096 |       4.50us       2.82us                         1.60x
    16      16384    256 |      12.74us      15.04us                         0.85x
    16      16384    512 |      13.38us      16.83us                         0.79x
    16      16384   1024 |      14.30us      16.80us                         0.85x
    16      16384   2048 |      16.26us      17.17us                         0.95x
    16      16384   4096 |      20.51us      17.52us                         1.17x
    16      65536    256 |      24.69us      16.86us                         1.46x
    16      65536    512 |      29.82us      16.80us                         1.78x
    16      65536   1024 |      30.69us      17.18us                         1.79x
    16      65536   2048 |      34.14us      20.29us                         1.68x
    16      65536   4096 |      39.68us      20.83us                         1.90x
    16     131072    256 |      40.78us      21.63us                         1.89x
    16     131072    512 |      41.25us      21.63us                         1.91x
    16     131072   1024 |      51.58us      21.66us                         2.38x
    16     131072   2048 |      52.78us      22.02us                         2.40x
    16     131072   4096 |      45.87us      27.94us                         1.64x
    16     262144    256 |      45.34us      28.64us                         1.58x
    16     262144    512 |      45.92us      28.70us                         1.60x
    16     262144   1024 |      46.34us      29.04us                         1.60x
    16     262144   2048 |      47.84us      29.60us                         1.62x
    16     262144   4096 |      49.70us      30.30us                         1.64x
    16     524288    256 |      86.62us      43.87us                         1.97x
    16     524288    512 |      87.30us      44.13us                         1.98x
    16     524288   1024 |      87.68us      44.35us                         1.98x
    16     524288   2048 |      89.30us      44.74us                         2.00x
    16     524288   4096 |      92.38us      45.36us                         2.04x
    64        256    256 |       2.85us       2.43us                         1.17x
    64        512    256 |       7.78us       7.39us                         1.05x
    64        512    512 |       2.91us       2.56us                         1.14x
    64       1024    256 |       8.58us       7.58us                         1.13x
    64       1024    512 |       8.06us       7.60us                         1.06x
    64       1024   1024 |       3.01us       2.66us                         1.13x
    64       2048    256 |       9.47us       8.38us                         1.13x
    64       2048    512 |       8.83us       8.35us                         1.06x
    64       2048   1024 |       8.29us       8.35us                         0.99x
    64       2048   2048 |       3.71us       3.33us                         1.12x
    64       4096    256 |       9.41us       8.74us                         1.08x
    64       4096    512 |       9.58us       8.86us                         1.08x
    64       4096   1024 |      10.34us       8.93us                         1.16x
    64       4096   2048 |       9.09us       8.61us                         1.06x
    64       4096   4096 |       4.70us       3.04us                         1.55x
    64      16384    256 |      13.06us      15.17us                         0.86x
    64      16384    512 |      13.86us      16.61us                         0.83x
    64      16384   1024 |      15.23us      16.80us                         0.91x
    64      16384   2048 |      16.45us      16.96us                         0.97x
    64      16384   4096 |      21.10us      17.44us                         1.21x
    64      65536    256 |      25.58us      21.89us                         1.17x
    64      65536    512 |      31.20us      22.02us                         1.42x
    64      65536   1024 |      31.87us      22.18us                         1.44x
    64      65536   2048 |      35.34us      28.45us                         1.24x
    64      65536   4096 |      41.31us      30.75us                         1.34x
    64     131072    256 |      42.14us      30.34us                         1.39x
    64     131072    512 |      42.80us      30.59us                         1.40x
    64     131072   1024 |      52.54us      30.85us                         1.70x
    64     131072   2048 |      54.40us      31.62us                         1.72x
    64     131072   4096 |      90.27us      49.18us                         1.84x
    64     262144    256 |      71.92us      48.66us                         1.48x
    64     262144    512 |      79.04us      48.56us                         1.63x
    64     262144   1024 |      79.22us      49.02us                         1.62x
    64     262144   2048 |      96.82us      49.54us                         1.95x
    64     262144   4096 |     146.72us      54.74us                         2.68x
    64     524288    256 |     158.80us      84.70us                         1.87x
    64     524288    512 |     158.82us      84.99us                         1.87x
    64     524288   1024 |     171.39us      85.76us                         2.00x
    64     524288   2048 |     172.91us      86.19us                         2.01x
    64     524288   4096 |     228.77us      90.02us                         2.54x
   256        256    256 |       3.74us       2.75us                         1.36x
   256        512    256 |      13.25us       8.72us                         1.52x
   256        512    512 |       3.81us       2.98us                         1.28x
   256       1024    256 |      15.84us       9.09us                         1.74x
   256       1024    512 |      13.47us       9.14us                         1.47x
   256       1024   1024 |       4.08us       3.36us                         1.21x
   256       2048    256 |      16.51us      10.29us                         1.60x
   256       2048    512 |      16.32us      10.37us                         1.57x
   256       2048   1024 |      14.53us      10.30us                         1.41x
   256       2048   2048 |       5.28us       3.97us                         1.33x
   256       4096    256 |      17.36us      11.82us                         1.47x
   256       4096    512 |      18.72us      11.97us                         1.56x
   256       4096   1024 |      18.43us      12.22us                         1.51x
   256       4096   2048 |      15.68us      11.42us                         1.37x
   256       4096   4096 |       7.52us       4.29us                         1.75x
   256      16384    256 |      26.72us      17.09us                         1.56x
   256      16384    512 |      28.35us      23.04us                         1.23x
   256      16384   1024 |      30.21us      23.97us                         1.26x
   256      16384   2048 |      33.22us      24.93us                         1.33x
   256      16384   4096 |      41.31us      26.30us                         1.57x
   256      65536    256 |      52.75us      43.46us                         1.21x
   256      65536    512 |      63.33us      44.13us                         1.44x
   256      65536   1024 |      66.08us      46.93us                         1.41x
   256      65536   2048 |      74.27us      78.90us                         0.94x
   256      65536   4096 |     152.87us      88.19us                         1.73x
   256     131072    256 |      89.89us      80.86us                         1.11x
   256     131072    512 |      90.99us      81.34us                         1.12x
   256     131072   1024 |     109.87us      82.70us                         1.33x
   256     131072   2048 |     114.05us      85.18us                         1.34x
   256     131072   4096 |     261.31us     162.07us                         1.61x
   256     262144    256 |     174.27us     154.00us                         1.13x
   256     262144    512 |     187.58us     154.53us                         1.21x
   256     262144   1024 |     189.57us     155.97us                         1.22x
   256     262144   2048 |     227.89us     158.94us                         1.43x
   256     262144   4096 |     422.98us     169.30us                         2.50x
   256     524288    256 |     322.55us     298.26us                         1.08x
   256     524288    512 |     323.62us     300.18us                         1.08x
   256     524288   1024 |     350.56us     302.34us                         1.16x
   256     524288   2048 |     354.67us     306.77us                         1.16x
   256     524288   4096 |     834.52us     316.07us                         2.64x

====================================================================================================
top_k_ragged_transform: Fused top-k + ragged index transform (dtype=FP32, deterministic=False, pattern=random)
====================================================================================================
 batch    seq_len      k |   FlashInfer     Clusters  Speedup Clusters vs. Default
-------------------------------------------------------------------------------------------------------------
     1        256    256 |       2.37us       1.76us                         1.35x
     1        512    256 |       3.58us       6.53us                         0.55x
     1        512    512 |       2.34us       1.82us                         1.28x
     1       1024    256 |       7.63us       6.94us                         1.10x
     1       1024    512 |       6.14us       6.88us                         0.89x
     1       1024   1024 |       2.75us       2.08us                         1.32x
     1       2048    256 |       6.72us       7.52us                         0.89x
     1       2048    512 |       7.97us       7.74us                         1.03x
     1       2048   1024 |       7.71us       7.68us                         1.00x
     1       2048   2048 |       2.85us       2.18us                         1.31x
     1       4096    256 |       8.22us       7.70us                         1.07x
     1       4096    512 |       7.23us       7.71us                         0.94x
     1       4096   1024 |       8.38us       7.86us                         1.07x
     1       4096   2048 |       6.83us       7.44us                         0.92x
     1       4096   4096 |       2.91us       2.30us                         1.26x
     1      16384    256 |      11.52us      11.87us                         0.97x
     1      16384    512 |      11.94us      12.51us                         0.95x
     1      16384   1024 |      12.74us      12.54us                         1.02x
     1      16384   2048 |      13.92us      12.51us                         1.11x
     1      16384   4096 |      17.50us      12.61us                         1.39x
     1      65536    256 |      30.88us      12.78us                         2.42x
     1      65536    512 |      31.14us      12.74us                         2.44x
     1      65536   1024 |      32.16us      12.74us                         2.53x
     1      65536   2048 |      33.07us      14.05us                         2.35x
     1      65536   4096 |      34.18us      14.21us                         2.41x
     1     131072    256 |      36.54us      14.94us                         2.45x
     1     131072    512 |      37.22us      14.91us                         2.50x
     1     131072   1024 |      37.41us      14.98us                         2.50x
     1     131072   2048 |      38.34us      14.98us                         2.56x
     1     131072   4096 |      39.84us      18.67us                         2.13x
     1     262144    256 |      41.28us      19.36us                         2.13x
     1     262144    512 |      41.68us      19.65us                         2.12x
     1     262144   1024 |      41.44us      19.55us                         2.12x
     1     262144   2048 |      42.88us      19.65us                         2.18x
     1     262144   4096 |      43.46us      19.49us                         2.23x
     1     524288    256 |      42.88us      28.00us                         1.53x
     1     524288    512 |      43.33us      27.97us                         1.55x
     1     524288   1024 |      43.49us      28.54us                         1.52x
     1     524288   2048 |      44.10us      28.46us                         1.55x
     1     524288   4096 |      44.13us      28.35us                         1.56x
    16        256    256 |       2.43us       1.86us                         1.31x
    16        512    256 |       6.42us       6.78us                         0.95x
    16        512    512 |       2.48us       1.92us                         1.29x
    16       1024    256 |       7.52us       6.85us                         1.10x
    16       1024    512 |       6.18us       7.01us                         0.88x
    16       1024   1024 |       2.43us       1.92us                         1.27x
    16       2048    256 |       7.87us       7.58us                         1.04x
    16       2048    512 |       7.90us       7.60us                         1.04x
    16       2048   1024 |       6.56us       7.65us                         0.86x
    16       2048   2048 |       2.53us       2.08us                         1.22x
    16       4096    256 |       8.16us       7.78us                         1.05x
    16       4096    512 |       8.35us       7.82us                         1.07x
    16       4096   1024 |       8.34us       8.00us                         1.04x
    16       4096   2048 |       7.97us       7.58us                         1.05x
    16       4096   4096 |       2.62us       2.21us                         1.19x
    16      16384    256 |      11.71us      14.37us                         0.82x
    16      16384    512 |      13.09us      15.65us                         0.84x
    16      16384   1024 |      13.12us      15.97us                         0.82x
    16      16384   2048 |      14.27us      15.84us                         0.90x
    16      16384   4096 |      17.60us      16.10us                         1.09x
    16      65536    256 |      23.79us      16.37us                         1.45x
    16      65536    512 |      28.91us      16.48us                         1.75x
    16      65536   1024 |      29.57us      16.74us                         1.77x
    16      65536   2048 |      32.03us      18.98us                         1.69x
    16      65536   4096 |      37.66us      19.46us                         1.94x
    16     131072    256 |      39.75us      20.67us                         1.92x
    16     131072    512 |      40.06us      20.80us                         1.93x
    16     131072   1024 |      49.31us      20.61us                         2.39x
    16     131072   2048 |      49.25us      20.48us                         2.40x
    16     131072   4096 |      43.62us      26.45us                         1.65x
    16     262144    256 |      45.54us      27.57us                         1.65x
    16     262144    512 |      45.63us      27.74us                         1.64x
    16     262144   1024 |      46.02us      27.65us                         1.66x
    16     262144   2048 |      46.43us      27.76us                         1.67x
    16     262144   4096 |      47.68us      28.23us                         1.69x
    16     524288    256 |      86.43us      42.54us                         2.03x
    16     524288    512 |      86.82us      42.75us                         2.03x
    16     524288   1024 |      87.07us      42.08us                         2.07x
    16     524288   2048 |      87.71us      42.38us                         2.07x
    16     524288   4096 |      89.17us      42.93us                         2.08x
    64        256    256 |       2.37us       1.79us                         1.32x
    64        512    256 |       7.46us       7.01us                         1.06x
    64        512    512 |       2.56us       1.95us                         1.31x
    64       1024    256 |       7.74us       7.14us                         1.09x
    64       1024    512 |       7.33us       7.10us                         1.03x
    64       1024   1024 |       2.53us       2.05us                         1.23x
    64       2048    256 |       8.00us       7.68us                         1.04x
    64       2048    512 |       8.03us       7.78us                         1.03x
    64       2048   1024 |       6.78us       7.71us                         0.88x
    64       2048   2048 |       2.59us       2.11us                         1.23x
    64       4096    256 |       9.20us       7.98us                         1.15x
    64       4096    512 |       8.54us       8.10us                         1.06x
    64       4096   1024 |       8.64us       8.13us                         1.06x
    64       4096   2048 |       8.29us       7.84us                         1.06x
    64       4096   4096 |       2.72us       2.27us                         1.20x
    64      16384    256 |      12.13us      14.11us                         0.86x
    64      16384    512 |      13.31us      15.36us                         0.87x
    64      16384   1024 |      14.27us      15.63us                         0.91x
    64      16384   2048 |      15.17us      15.74us                         0.96x
    64      16384   4096 |      18.10us      16.03us                         1.13x
    64      65536    256 |      25.06us      20.64us                         1.21x
    64      65536    512 |      29.25us      20.58us                         1.42x
    64      65536   1024 |      29.89us      20.67us                         1.45x
    64      65536   2048 |      32.94us      26.56us                         1.24x
    64      65536   4096 |      38.99us      27.26us                         1.43x
    64     131072    256 |      41.06us      28.88us                         1.42x
    64     131072    512 |      41.22us      28.99us                         1.42x
    64     131072   1024 |      50.37us      29.09us                         1.73x
    64     131072   2048 |      50.83us      29.22us                         1.74x
    64     131072   4096 |      86.02us      43.84us                         1.96x
    64     262144    256 |      70.35us      47.04us                         1.50x
    64     262144    512 |      76.61us      47.07us                         1.63x
    64     262144   1024 |      77.41us      47.26us                         1.64x
    64     262144   2048 |      93.22us      47.54us                         1.96x
    64     262144   4096 |     137.71us      47.97us                         2.87x
    64     524288    256 |     157.73us      83.46us                         1.89x
    64     524288    512 |     157.31us      83.46us                         1.88x
    64     524288   1024 |     169.62us      83.81us                         2.02x
    64     524288   2048 |     170.10us      84.21us                         2.02x
    64     524288   4096 |     222.91us      84.21us                         2.65x
   256        256    256 |       2.94us       2.22us                         1.32x
   256        512    256 |      12.10us       8.16us                         1.48x
   256        512    512 |       3.09us       2.27us                         1.36x
   256       1024    256 |      14.05us       8.67us                         1.62x
   256       1024    512 |      11.49us       8.67us                         1.32x
   256       1024   1024 |       3.39us       2.45us                         1.39x
   256       2048    256 |      15.84us       9.60us                         1.65x
   256       2048    512 |      14.88us       9.66us                         1.54x
   256       2048   1024 |      13.15us       9.76us                         1.35x
   256       2048   2048 |       3.74us       2.46us                         1.52x
   256       4096    256 |      16.48us      10.98us                         1.50x
   256       4096    512 |      16.16us      11.10us                         1.46x
   256       4096   1024 |      16.93us      11.33us                         1.49x
   256       4096   2048 |      14.26us      10.72us                         1.33x
   256       4096   4096 |       4.13us       2.66us                         1.55x
   256      16384    256 |      24.54us      16.35us                         1.50x
   256      16384    512 |      25.57us      22.05us                         1.16x
   256      16384   1024 |      27.52us      22.66us                         1.21x
   256      16384   2048 |      29.82us      23.07us                         1.29x
   256      16384   4096 |      35.14us      23.14us                         1.52x
   256      65536    256 |      50.02us      42.40us                         1.18x
   256      65536    512 |      59.97us      42.64us                         1.41x
   256      65536   1024 |      60.35us      43.10us                         1.40x
   256      65536   2048 |      65.46us      72.35us                         0.90x
   256      65536   4096 |     144.24us      74.11us                         1.95x
   256     131072    256 |      87.68us      79.49us                         1.10x
   256     131072    512 |      88.13us      79.78us                         1.10x
   256     131072   1024 |     106.06us      80.19us                         1.32x
   256     131072   2048 |     106.21us      80.48us                         1.32x
   256     131072   4096 |     250.90us     143.89us                         1.74x
   256     262144    256 |     172.16us     152.48us                         1.13x
   256     262144    512 |     185.03us     152.80us                         1.21x
   256     262144   1024 |     185.33us     153.22us                         1.21x
   256     262144   2048 |     220.74us     153.57us                         1.44x
   256     262144   4096 |     403.57us     154.13us                         2.62x
   256     524288    256 |     320.91us     295.36us                         1.09x
   256     524288    512 |     320.71us     296.06us                         1.08x
   256     524288   1024 |     346.91us     296.64us                         1.17x
   256     524288   2048 |     347.35us     297.58us                         1.17x
   256     524288   4096 |     815.83us     298.11us                         2.74x


🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added a clustered Top‑K option with public APIs, automatic dispatch, and support for page‑table and ragged inputs.
  • Benchmarks

    • Added "clusters" measurement path, CUDA-graph timing, and speedup-vs-default reporting in benchmark output.
  • Tests

    • New correctness and coverage tests for clustered Top‑K gated by compute capability with dtype-specific accuracy thresholds.
  • Chores

    • Updated JIT/build to include clustered kernels and pass CUDA compile flags; added a cached shared‑memory query utility.
  • Style

    • Tightened test assertions, improved GPU timing sync, and adjusted benchmark output formatting.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Apr 7, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds an exact clustered Top‑K implementation: new CUDA headers/kernels and TVM bindings, JIT/build and Python dispatch updates to use clustered kernels, benchmark/test updates (including CUDA graph & CUPTI timing tweak), and a small cached device utility for shared‑memory opt‑in.

Changes

Cohort / File(s) Summary
Clustered CUDA kernels & traits
include/flashinfer/fast_topk_clusters_exact.cuh, include/flashinfer/topk_common.cuh, include/flashinfer/topk.cuh
Add SM9+ cooperative-cluster exact Top‑K device routine and three kernel wrappers, host launchers/specializations, overflow caching, and move RadixTopKTraits into topk_common.cuh (remove old in-file traits).
TVM / CUDA binding
csrc/flashinfer_fast_topk_clusters_binding.cu
New TVM-FFI entrypoints fast_topk_clusters_exact* with runtime shape/dtype checks, stride extraction, optional histogram/value pointer handling, stream acquisition, kernel dispatch, and post-launch CUDA error checks.
Python JIT & build spec
flashinfer/jit/topk.py
Include new CUDA binding source in JIT inputs and add extra_cuda_cflags=["-lineinfo"] to the generated module spec.
Python API & runtime dispatch
flashinfer/topk.py
Add clustered wrapper functions, heuristics (get_fast_topk_clusters, roundup_kbyte), can_use_clusters_topk gate, shared‑memory opt‑in use, and route top_k/page-table/ragged transforms to clustered kernels when appropriate (with sorting/gather for sorted=True).
Benchmarks
benchmarks/bench_topk.py
Enable CUDA Graph capture for median timing paths, add "default" and new "clusters" measurement phases (record fast_topk_us and speedup_vs_flashinfer), pass enable_cupti=True/use_cuda_graph=True for cluster runs, and extend CSV/console headers and per-case output formatting.
CUPTI timing util
flashinfer/testing/utils.py
Insert torch.cuda.synchronize() inside the per-iteration CUPTI timestamp loop to ensure ordering before timestamp capture.
Utilities
flashinfer/utils.py
Add cached helper get_shared_bytes_per_block_optin(device: torch.device) decorated with functools.cache.
Tests
tests/utils/test_topk.py
Tighten compute_topk_accuracy cardinality assertion, lower min accuracy threshold (0.98→0.97), and add SM100/SM103-gated parametrized correctness tests for clustered kernels (direct/page-table/ragged) including shape/dtype/value, per-row bounds, and accuracy checks.

Sequence Diagram(s)

sequenceDiagram
    participant Host as Host Code / User
    participant PyAPI as Python Wrapper\n(topk_clusters_* / top_k)
    participant TVMBind as TVM FFI Binding
    participant Kernel as CUDA Kernel\nfast_topk_cuda_v4
    participant Overflow as Global Overflow Cache

    Host->>PyAPI: call logits, top_k, options
    PyAPI->>PyAPI: allocate indices, values?, cached_overflow
    PyAPI->>TVMBind: call fast_topk_clusters_exact*(..., cached_overflow, ...)
    TVMBind->>TVMBind: validate shapes/dtypes/strides, get stream
    TVMBind->>Kernel: launch specialized kernel on stream

    rect rgba(100,150,200,0.5)
        Kernel->>Kernel: per-block histograms\ncompute threshold_bin
    end

    rect rgba(150,100,200,0.5)
        Kernel->>Kernel: emit >threshold, cache/spill equals\niterate refinement rounds
        Kernel->>Overflow: spill overflow candidates
        Overflow-->>Kernel: supply spilled candidates
    end

    Kernel->>TVMBind: write final indices (and values)
    TVMBind->>PyAPI: return tensors
    PyAPI->>Host: return results
Loading
sequenceDiagram
    participant Env as Environment
    participant API as top_k()/top_k_*_transform()
    participant Selector as Algorithm Selector
    participant Clusters as Clustered Path
    participant Radix as Radix Path
    participant Host as Return Results

    Env->>API: FLASHINFER_TOPK_ALGO
    API->>Selector: check env var, device, deterministic

    alt select clusters
        Selector->>Clusters: call clustered wrapper
        Clusters->>Host: return indices (+ values, sorted if requested)
    else
        Selector->>Radix: call existing radix multi-CTA path
        Radix->>Host: return indices (+ values)
    end
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • yzh119
  • kahyunnam
  • cyx-6
  • jimmyzho
  • bkryu
  • aleozlx
  • nv-yunzheq
  • jiahanc

Poem

🐰 In shared‑mem meadows histograms rise,

clusters hop thresholds, chasing the prize.
Overflows tumble, then neatly align,
indices hop home in cooperative time.
Benchmarks clap—kernels hum, all is fine.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 39.47% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title '[feat] Faster topk algorithm' clearly and concisely describes the main addition: a new, faster top-k algorithm implementation.
Description check ✅ Passed The PR description provides a clear overview of the implementation (SM90+ clusters, non-deterministic with overflow to global memory), includes extensive benchmark results comparing against torch.topk and FlashInfer defaults, and confirms pre-commit checks and tests were completed.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new cluster-based top-k algorithm optimized for Blackwell GPUs (SM 100/103), featuring standard, page table, and ragged transform implementations. The changes include high-performance CUDA kernels using cooperative groups, TVM FFI bindings, and Python API integration with an environment variable toggle for algorithm selection. Feedback focuses on improving code quality by addressing const correctness in the C++ bindings, replacing magic numbers with descriptive constants in the Python wrappers, and enhancing type safety in shared memory calculations.

Comment thread csrc/flashinfer_fast_topk_clusters_binding.cu Outdated
Comment thread flashinfer/topk.py Outdated
Comment thread include/flashinfer/fast_topk_clusters_exact.cuh Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
benchmarks/bench_topk.py (1)

92-130: ⚠️ Potential issue | 🟠 Major

Don’t overwrite the benchmark mode that the caller selected.

Lines 92, 359, and 432 force FLASHINFER_TOPK_ALGO="default" before the baseline timing. That breaks --compare-algorithms: the outer loop sets "multi_cta"/"filtered" right before calling these helpers, but both paths get benchmarked as the same "default" baseline. The extra "clusters" run also is not exception-safe, so a failure there leaves later cases with the wrong env.

Example approach
+@contextmanager
+def temporary_topk_algo(algo: str | None):
+    previous = os.environ.get("FLASHINFER_TOPK_ALGO")
+    try:
+        if algo is None or algo == "auto":
+            os.environ.pop("FLASHINFER_TOPK_ALGO", None)
+        else:
+            os.environ["FLASHINFER_TOPK_ALGO"] = algo
+        yield
+    finally:
+        if previous is None:
+            os.environ.pop("FLASHINFER_TOPK_ALGO", None)
+        else:
+            os.environ["FLASHINFER_TOPK_ALGO"] = previous
-    set_topk_algo("default")
     fi_ms, fi_nondeterministic_ms = bench_flashinfer_modes(...)
     ...
-    set_topk_algo("clusters")
-    fast_topk_ms = bench_median_ms(lambda: flashinfer.top_k(scores, k))
-    set_topk_algo("auto")
+    with temporary_topk_algo("clusters"):
+        fast_topk_ms = bench_median_ms(lambda: flashinfer.top_k(scores, k))

Also applies to: 359-398, 432-469

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_topk.py` around lines 92 - 130, The benchmark currently
overwrites the caller's selected FLASHINFER_TOPK_ALGO by calling
set_topk_algo("default") and later
set_topk_algo("clusters")/set_topk_algo("auto"); update bench_topk.py so it
preserves and restores the prior top-k algorithm instead of forcing "default":
capture the current algo before any set_topk_algo calls, run
bench_flashinfer_modes/bench_median_ms as intended, and restore the original
algo in a finally block (or remove the initial set_topk_algo("default") entirely
if unnecessary) to make the code exception-safe; reference functions/identifiers
to change: set_topk_algo, bench_flashinfer_modes, flashinfer.top_k,
bench_median_ms, and the blocks that set "clusters" and "auto".
🧹 Nitpick comments (1)
tests/utils/test_topk.py (1)

2303-2327: Strengthen the exact-path assertions.

topk_clusters_exact is the correctness-preserving path, but this still passes if a row contains a few wrong selections as long as overlap stays above the threshold and the per-row min/max happen to match. Comparing values against gather(logits, indices) and checking the k-th-value threshold row-wise is a much tighter signal without depending on tie ordering.

Example tightening
     if output_values:
         assert values is not None
         assert values.shape == (batch_size, k)
         assert values.dtype == dtype

         abs_err = 0.125 if dtype == torch.bfloat16 else 1e-5
         rel_err = 0.1 if dtype == torch.bfloat16 else 1e-5
-        torch.testing.assert_close(
-            values.min(dim=-1).values,
-            ref_values.min(dim=-1).values,
-            rtol=rel_err,
-            atol=abs_err,
-        )
-        torch.testing.assert_close(
-            values.max(dim=-1).values,
-            ref_values.max(dim=-1).values,
-            rtol=rel_err,
-            atol=abs_err,
-        )
+        gathered_values = torch.gather(logits, dim=-1, index=indices.long())
+        torch.testing.assert_close(
+            values,
+            gathered_values,
+            rtol=rel_err,
+            atol=abs_err,
+        )
+        assert verify_topk_correctness(logits, values, indices.long(), k)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/utils/test_topk.py` around lines 2303 - 2327, When output_values is
True, strengthen exact-path assertions by verifying that returned values exactly
match the selected logits and that each row meets the k-th-value threshold: 1)
compute gathered = torch.gather(logits, -1, indices) and assert
values.shape/dtype and torch.testing.assert_close(values, gathered) to ensure
the exact selected entries match; 2) compute kth = torch.topk(logits,
k).values[:, -1] (or equivalently gather the k-th threshold via ref_indices) and
assert every value in values per row is >= kth (row-wise) to ensure no value
below the k-th threshold was selected; keep the existing accuracy check
(compute_topk_accuracy(indices, ref_indices.int(), ...)) afterward.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/flashinfer_fast_topk_clusters_binding.cu`:
- Around line 56-73: The dispatch macro DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16
excludes bfloat16 so torch.bfloat16 inputs never reach the launchers; update the
dispatch to include BF16 (or add a separate branch) so the template calls to
launch_fast_topk_clusters_exact (and similar callers at the other spots) are
instantiated with nv_bfloat16 / OrderedBits<nv_bfloat16>. Replace
DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16 with a macro that also maps the DLPack
dtype for bfloat16 to the correct c_type (nv_bfloat16) or add an explicit
if-case for torch.bfloat16 that calls
launch_fast_topk_clusters_exact<c_type=nv_bfloat16,...> (keeping idx_int64/int
branches and the same argument casts) so BF16 paths reach the kernel.

In `@flashinfer/topk.py`:
- Around line 309-414: The clustered fast_topk kernels are being used
unconditionally; guard topk_clusters_exact, topk_clusters_page_table_transform,
and topk_clusters_ragged_transform (and the other similar helpers noted) behind
the backend/capability gate so we fall back to the radix path unless the device
explicitly supports the clustered backend. Concretely, annotate these APIs with
the `@backend_requirement` decorator and use the provided is_backend_supported() /
is_compute_capability_supported(cc) checks (or the existing helper that checks
FLASHINFER_TOPK_ALGO) to choose the clustered code path only when supported;
otherwise call the original radix implementation (preserve the previous return
types) and do not allocate the large overflow buffers or call fast_topk_* on
unsupported devices. Ensure the same gating is applied to the transform helpers
as well.

In `@include/flashinfer/fast_topk_clusters_exact.cuh`:
- Around line 477-485: In the fast path where seq_len <= TopK you only
initialize output_indices; also write output_values for each slot: inside the
same loop that writes output_indices (use ind_offset and i), set
output_values[ind_offset + i] to the corresponding input value for i < seq_len
(use the same source values array used elsewhere in this function, e.g., values
or input_values) and set output_values[ind_offset + i] to a sentinel for empty
slots (e.g., -INFINITY or numeric_limits<ValueT>::lowest()) when i >= seq_len so
the values tensor is always initialized.

---

Outside diff comments:
In `@benchmarks/bench_topk.py`:
- Around line 92-130: The benchmark currently overwrites the caller's selected
FLASHINFER_TOPK_ALGO by calling set_topk_algo("default") and later
set_topk_algo("clusters")/set_topk_algo("auto"); update bench_topk.py so it
preserves and restores the prior top-k algorithm instead of forcing "default":
capture the current algo before any set_topk_algo calls, run
bench_flashinfer_modes/bench_median_ms as intended, and restore the original
algo in a finally block (or remove the initial set_topk_algo("default") entirely
if unnecessary) to make the code exception-safe; reference functions/identifiers
to change: set_topk_algo, bench_flashinfer_modes, flashinfer.top_k,
bench_median_ms, and the blocks that set "clusters" and "auto".

---

Nitpick comments:
In `@tests/utils/test_topk.py`:
- Around line 2303-2327: When output_values is True, strengthen exact-path
assertions by verifying that returned values exactly match the selected logits
and that each row meets the k-th-value threshold: 1) compute gathered =
torch.gather(logits, -1, indices) and assert values.shape/dtype and
torch.testing.assert_close(values, gathered) to ensure the exact selected
entries match; 2) compute kth = torch.topk(logits, k).values[:, -1] (or
equivalently gather the k-th threshold via ref_indices) and assert every value
in values per row is >= kth (row-wise) to ensure no value below the k-th
threshold was selected; keep the existing accuracy check
(compute_topk_accuracy(indices, ref_indices.int(), ...)) afterward.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: e69861f6-01ca-4756-9445-13bbe31545f5

📥 Commits

Reviewing files that changed from the base of the PR and between edcef4b and 849bd57.

📒 Files selected for processing (7)
  • benchmarks/bench_topk.py
  • csrc/flashinfer_fast_topk_clusters_binding.cu
  • flashinfer/jit/topk.py
  • flashinfer/testing/utils.py
  • flashinfer/topk.py
  • include/flashinfer/fast_topk_clusters_exact.cuh
  • tests/utils/test_topk.py

Comment thread csrc/flashinfer_fast_topk_clusters_binding.cu
Comment thread flashinfer/topk.py
Comment thread include/flashinfer/fast_topk_clusters_exact.cuh
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (1)
csrc/flashinfer_fast_topk_clusters_binding.cu (1)

56-71: ⚠️ Potential issue | 🔴 Critical

BF16 still never reaches the clustered launchers.

All three entrypoints still dispatch through DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16, so torch.bfloat16 inputs are rejected before launch. top_k, top_k_page_table_transform, and top_k_ragged_transform now route non-deterministic calls here by default, so this is a user-visible regression for advertised BF16 inputs. Please add a BF16-capable dispatch (or an explicit nv_bfloat16 branch) at each site.

Also applies to: 103-111, 143-151

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/flashinfer_fast_topk_clusters_binding.cu` around lines 56 - 71, The
dispatch macro DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16 currently blocks
torch.bfloat16 inputs; update the dispatch at this call site (wrapping
launch_fast_topk_clusters_exact) to include BF16 support by either using/adding
a DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16_BF16 (or equivalent) macro or adding
an explicit nv_bfloat16 branch that calls
launch_fast_topk_clusters_exact<c_type, ...> with c_type = nv_bfloat16; make the
same change for the other two entrypoints that route here (the top_k,
top_k_page_table_transform, and top_k_ragged_transform call sites) so BF16
tensors reach the clustered launchers instead of being rejected.
🧹 Nitpick comments (1)
flashinfer/topk.py (1)

346-382: Wrap the new clustered helpers with @flashinfer_api.

topk_clusters_exact, topk_clusters_page_table_transform, and topk_clusters_ragged_transform are top-level helpers in this module, but unlike the existing public top-k APIs they currently bypass the standard logging wrapper.

🧩 Suggested fix
+@flashinfer_api
 def topk_clusters_exact(
     logits, top_k, output_values=False, out_dtype=torch.int32, pdl=False
 ):
     ...

+@flashinfer_api
 def topk_clusters_page_table_transform(
     logits, seq_lens, src_page_table, top_k, pdl=False
 ):
     ...

+@flashinfer_api
 def topk_clusters_ragged_transform(logits, seq_lens, offsets, top_k, pdl=False):
     ...
As per coding guidelines: Enable API logging in production debugging using `flashinfer_api` decorator and environment variables: FLASHINFER_LOGLEVEL (0/1/3/5) and FLASHINFER_LOGDEST.

Also applies to: 385-413, 416-442

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/topk.py` around lines 346 - 382, The three new top-level helpers
(topk_clusters_exact, topk_clusters_page_table_transform,
topk_clusters_ragged_transform) must be wrapped with the flashinfer_api
decorator to enable standard logging; add `@flashinfer_api` directly above each
def while preserving their signatures, and ensure flashinfer_api is
imported/available in this module if not already. Place the decorator
immediately above the function definitions (no other changes to arguments/return
types), so the functions use the FLASHINFER_LOGLEVEL/FLASHINFER_LOGDEST behavior
used by the existing public top-k APIs.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/topk.py`:
- Around line 357-363: The per-cluster overflow capacity uses floor division and
must be rounded up to avoid under-allocating when max_model_len % num_clusters
!= 0; replace occurrences where topk_global_overflow = max_model_len //
num_clusters with a ceiling division (e.g., (max_model_len + num_clusters - 1)
// num_clusters) and then allocate overflow_buf (and the analogous buffers in
the other two clustered helpers) using that rounded-up topk_global_overflow so
the kernel’s per-cluster overflow_stride never overruns; update all three
clustered helper sites that define topk_global_overflow and allocate
overflow_buf accordingly.

---

Duplicate comments:
In `@csrc/flashinfer_fast_topk_clusters_binding.cu`:
- Around line 56-71: The dispatch macro DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16
currently blocks torch.bfloat16 inputs; update the dispatch at this call site
(wrapping launch_fast_topk_clusters_exact) to include BF16 support by either
using/adding a DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16_BF16 (or equivalent)
macro or adding an explicit nv_bfloat16 branch that calls
launch_fast_topk_clusters_exact<c_type, ...> with c_type = nv_bfloat16; make the
same change for the other two entrypoints that route here (the top_k,
top_k_page_table_transform, and top_k_ragged_transform call sites) so BF16
tensors reach the clustered launchers instead of being rejected.

---

Nitpick comments:
In `@flashinfer/topk.py`:
- Around line 346-382: The three new top-level helpers (topk_clusters_exact,
topk_clusters_page_table_transform, topk_clusters_ragged_transform) must be
wrapped with the flashinfer_api decorator to enable standard logging; add
`@flashinfer_api` directly above each def while preserving their signatures, and
ensure flashinfer_api is imported/available in this module if not already. Place
the decorator immediately above the function definitions (no other changes to
arguments/return types), so the functions use the
FLASHINFER_LOGLEVEL/FLASHINFER_LOGDEST behavior used by the existing public
top-k APIs.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ba20616d-607b-4a87-9e1b-f1b6aace201c

📥 Commits

Reviewing files that changed from the base of the PR and between 849bd57 and 8014c33.

📒 Files selected for processing (5)
  • benchmarks/bench_topk.py
  • csrc/flashinfer_fast_topk_clusters_binding.cu
  • flashinfer/topk.py
  • flashinfer/utils.py
  • include/flashinfer/fast_topk_clusters_exact.cuh
✅ Files skipped from review due to trivial changes (1)
  • include/flashinfer/fast_topk_clusters_exact.cuh
🚧 Files skipped from review as they are similar to previous changes (1)
  • benchmarks/bench_topk.py

Comment thread flashinfer/topk.py
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
flashinfer/topk.py (1)

315-316: Remove unused function.

roundup_kbyte is defined but never called anywhere in this file. Consider removing it or documenting its intended future use.

🧹 Proposed fix
-def roundup_kbyte(x):
-    return (x + 1023) // 1024 * 1024
-
-
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/topk.py` around lines 315 - 316, The function roundup_kbyte is
unused in this module; remove the unused function definition roundup_kbyte from
flashinfer/topk.py to eliminate dead code, or if it is intended for future use
add a clear docstring and a unit test or TODO comment referencing its intended
caller (e.g., any functions that need kilobyte alignment) so its presence is
justified—prefer removing it unless you add documentation/tests to show it's
required.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@flashinfer/topk.py`:
- Around line 315-316: The function roundup_kbyte is unused in this module;
remove the unused function definition roundup_kbyte from flashinfer/topk.py to
eliminate dead code, or if it is intended for future use add a clear docstring
and a unit test or TODO comment referencing its intended caller (e.g., any
functions that need kilobyte alignment) so its presence is justified—prefer
removing it unless you add documentation/tests to show it's required.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f6fafc55-d4ed-4d8f-ba26-0f88d48067b8

📥 Commits

Reviewing files that changed from the base of the PR and between 8014c33 and 42da8a7.

📒 Files selected for processing (1)
  • flashinfer/topk.py

@Aalanli
Copy link
Copy Markdown
Contributor Author

Aalanli commented Apr 8, 2026

Here are benchmarks for fp16:

> python benchmarks/bench_topk.py --op=top_k --input-pattern=random --dtype=fp16
====================================================================================================
top_k: Basic radix-based top-k selection (dtype=FP16, deterministic=False, pattern=random)
NOTE: default top-k sweep includes two extra large-batch/long-vocab stress cases beyond the original grid
====================================================================================================
 batch    seq_len      k |   FlashInfer   torch.topk    Speedup     Clusters  Speedup Clusters vs. Default
-------------------------------------------------------------------------------------------------------------------
     1        256    256 |         4.37us      14.18us      3.25x       1.89us                         2.31x
     1        512    256 |         7.58us      15.78us      2.08x       4.29us                         1.77x
     1        512    512 |         5.31us      15.49us      2.92x       1.95us                         2.72x
     1       1024    256 |         8.22us      19.65us      2.39x       4.61us                         1.78x
     1       1024    512 |         6.85us      17.54us      2.56x       4.32us                         1.58x
     1       1024   1024 |         5.09us      17.73us      3.48x       2.11us                         2.41x
     1       2048    256 |         8.10us      24.03us      2.97x       4.80us                         1.69x
     1       2048    512 |         9.06us      23.36us      2.58x       4.83us                         1.87x
     1       2048   1024 |         6.91us      21.87us      3.16x       4.64us                         1.49x
     1       2048   2048 |         5.81us      18.82us      3.24x       2.69us                         2.16x
     1       4096    256 |         8.13us      27.14us      3.34x       5.02us                         1.62x
     1       4096    512 |         9.34us      29.89us      3.20x       5.41us                         1.73x
     1       4096   1024 |         9.28us      29.79us      3.21x       5.38us                         1.73x
     1       4096   2048 |         7.81us      25.54us      3.27x       5.12us                         1.52x
     1       4096   4096 |         6.91us      27.58us      3.99x       2.56us                         2.70x
     1      16384    256 |        10.40us      64.32us      6.18x       7.97us                         1.31x
     1      16384    512 |        11.78us      69.82us      5.93x       8.18us                         1.44x
     1      16384   1024 |        12.35us      68.27us      5.53x       8.13us                         1.52x
     1      16384   2048 |        14.11us      68.42us      4.85x       8.51us                         1.66x
     1      16384   4096 |        16.45us      66.50us      4.04x       8.67us                         1.90x
     1      65536    256 |        29.50us      61.65us      2.09x       8.77us                         3.36x
     1      65536    512 |        31.57us      58.88us      1.87x       9.47us                         3.33x
     1      65536   1024 |        34.37us      62.67us      1.82x       9.34us                         3.68x
     1      65536   2048 |        36.08us      61.98us      1.72x       9.70us                         3.72x
     1      65536   4096 |        38.21us      62.72us      1.64x      10.08us                         3.79x
     1     131072    256 |        35.90us      63.68us      1.77x      10.98us                         3.27x
     1     131072    512 |        36.48us      65.28us      1.79x      10.91us                         3.34x
     1     131072   1024 |        38.66us      67.23us      1.74x      12.06us                         3.20x
     1     131072   2048 |        40.06us      64.86us      1.62x      11.92us                         3.36x
     1     131072   4096 |        42.67us      64.74us      1.52x      12.83us                         3.33x
     1     262144    256 |        42.21us      71.70us      1.70x      14.66us                         2.88x
     1     262144    512 |        43.78us      73.31us      1.67x      15.23us                         2.87x
     1     262144   1024 |        44.80us      75.14us      1.68x      15.46us                         2.90x
     1     262144   2048 |        47.17us      74.88us      1.59x      17.14us                         2.75x
     1     262144   4096 |        49.92us      73.92us      1.48x      17.44us                         2.86x
     1     524288    256 |        47.39us      89.22us      1.88x      22.18us                         2.14x
     1     524288    512 |        48.53us      88.13us      1.82x      22.26us                         2.18x
     1     524288   1024 |        49.60us      92.86us      1.87x      23.63us                         2.10x
     1     524288   2048 |        51.07us      94.30us      1.85x      23.78us                         2.15x
     1     524288   4096 |        52.90us      90.46us      1.71x      27.41us                         1.93x
    16        256    256 |         6.11us      15.26us      2.50x       1.98us                         3.08x
    16        512    256 |         9.09us      17.22us      1.89x       4.38us                         2.07x
    16        512    512 |         5.44us      16.90us      3.11x       2.08us                         2.62x
    16       1024    256 |         9.15us      20.86us      2.28x       4.90us                         1.87x
    16       1024    512 |         8.48us      20.38us      2.40x       4.64us                         1.83x
    16       1024   1024 |         6.40us      20.45us      3.19x       2.30us                         2.78x
    16       2048    256 |         9.39us      24.54us      2.61x       5.18us                         1.81x
    16       2048    512 |         9.06us      24.67us      2.72x       5.03us                         1.80x
    16       2048   1024 |         9.15us      24.61us      2.69x       4.93us                         1.86x
    16       2048   2048 |         7.07us      23.58us      3.33x       2.88us                         2.46x
    16       4096    256 |         9.81us      32.32us      3.30x       5.31us                         1.85x
    16       4096    512 |         9.41us      32.37us      3.44x       5.57us                         1.69x
    16       4096   1024 |         9.95us      32.80us      3.30x       5.66us                         1.76x
    16       4096   2048 |         9.73us      31.26us      3.21x       5.57us                         1.75x
    16       4096   4096 |         7.71us      31.23us      4.05x       2.66us                         2.90x
    16      16384    256 |        12.03us      76.38us      6.35x       9.70us                         1.24x
    16      16384    512 |        12.67us      73.86us      5.83x       9.98us                         1.27x
    16      16384   1024 |        13.41us      74.13us      5.53x      10.24us                         1.31x
    16      16384   2048 |        14.69us      76.99us      5.24x      10.85us                         1.35x
    16      16384   4096 |        17.15us      79.89us      4.66x      11.20us                         1.53x
    16      65536    256 |        31.33us      68.91us      2.20x      11.78us                         2.66x
    16      65536    512 |        32.83us      71.84us      2.19x      12.38us                         2.65x
    16      65536   1024 |        34.88us      71.42us      2.05x      12.51us                         2.79x
    16      65536   2048 |        36.48us      70.64us      1.94x      12.93us                         2.82x
    16      65536   4096 |        39.36us      72.29us      1.84x      13.76us                         2.86x
    16     131072    256 |        37.38us      79.01us      2.11x      14.98us                         2.50x
    16     131072    512 |        37.98us      80.64us      2.12x      14.83us                         2.56x
    16     131072   1024 |        39.71us      77.14us      1.94x      16.51us                         2.41x
    16     131072   2048 |        41.54us      80.66us      1.94x      16.70us                         2.49x
    16     131072   4096 |        43.39us      81.71us      1.88x      18.11us                         2.40x
    16     262144    256 |        44.32us      94.30us      2.13x      20.48us                         2.16x
    16     262144    512 |        45.15us      97.09us      2.15x      21.54us                         2.10x
    16     262144   1024 |        46.06us      95.23us      2.07x      21.60us                         2.13x
    16     262144   2048 |        48.48us      95.42us      1.97x      24.54us                         1.98x
    16     262144   4096 |        50.62us      98.69us      1.95x      25.23us                         2.01x
    16     524288    256 |        49.70us     125.28us      2.52x      32.19us                         1.54x
    16     524288    512 |        50.27us     126.99us      2.53x      32.38us                         1.55x
    16     524288   1024 |        51.39us     123.98us      2.41x      34.45us                         1.49x
    16     524288   2048 |        52.03us     128.42us      2.47x      34.62us                         1.50x
    16     524288   4096 |        54.51us     129.76us      2.38x      41.09us                         1.33x
    64        256    256 |         6.08us      15.81us      2.60x       2.14us                         2.84x
    64        512    256 |         9.41us      17.74us      1.89x       4.70us                         2.00x
    64        512    512 |         5.34us      17.47us      3.27x       2.24us                         2.39x
    64       1024    256 |         9.89us      21.76us      2.20x       5.06us                         1.96x
    64       1024    512 |         8.96us      21.34us      2.38x       4.83us                         1.85x
    64       1024   1024 |         6.62us      21.76us      3.28x       2.43us                         2.72x
    64       2048    256 |         9.76us      25.15us      2.58x       5.25us                         1.86x
    64       2048    512 |         8.96us      25.06us      2.80x       5.25us                         1.71x
    64       2048   1024 |         9.71us      25.63us      2.64x       5.06us                         1.92x
    64       2048   2048 |         7.30us      23.97us      3.29x       3.10us                         2.35x
    64       4096    256 |        10.18us      32.58us      3.20x       5.57us                         1.83x
    64       4096    512 |         9.34us      32.35us      3.46x       5.73us                         1.63x
    64       4096   1024 |        10.53us      33.50us      3.18x       5.86us                         1.80x
    64       4096   2048 |        10.37us      33.31us      3.21x       5.63us                         1.84x
    64       4096   4096 |         8.53us      35.46us      4.16x       3.01us                         2.84x
    64      16384    256 |        12.42us      68.46us      5.51x      10.18us                         1.22x
    64      16384    512 |        12.58us      67.46us      5.36x      10.56us                         1.19x
    64      16384   1024 |        14.27us      68.96us      4.83x      11.01us                         1.30x
    64      16384   2048 |        15.46us      67.71us      4.38x      11.26us                         1.37x
    64      16384   4096 |        17.26us      72.77us      4.22x      11.70us                         1.48x
    64      65536    256 |        32.18us      94.34us      2.93x      15.58us                         2.06x
    64      65536    512 |        33.70us      94.30us      2.80x      17.28us                         1.95x
    64      65536   1024 |        35.52us      94.85us      2.67x      17.34us                         2.05x
    64      65536   2048 |        37.41us      95.62us      2.56x      18.67us                         2.00x
    64      65536   4096 |        39.78us      99.95us      2.51x      20.54us                         1.94x
    64     131072    256 |        38.94us     120.32us      3.09x      22.30us                         1.75x
    64     131072    512 |        39.50us     125.06us      3.17x      22.40us                         1.76x
    64     131072   1024 |        40.86us     126.50us      3.10x      25.55us                         1.60x
    64     131072   2048 |        42.56us     126.34us      2.97x      26.05us                         1.63x
    64     131072   4096 |        44.67us     130.26us      2.92x      29.09us                         1.54x
    64     262144    256 |        83.06us     168.83us      2.03x      34.50us                         2.41x
    64     262144    512 |        84.35us     167.94us      1.99x      36.48us                         2.31x
    64     262144   1024 |        86.24us     171.57us      1.99x      36.77us                         2.35x
    64     262144   2048 |        90.08us     173.34us      1.92x      43.30us                         2.08x
    64     262144   4096 |        94.56us     177.74us      1.88x      44.62us                         2.12x
    64     524288    256 |       138.50us     275.77us      1.99x      62.22us                         2.23x
    64     524288    512 |       139.09us     277.17us      1.99x      62.24us                         2.23x
    64     524288   1024 |       141.60us     278.42us      1.97x      66.66us                         2.12x
    64     524288   2048 |       143.87us     280.34us      1.95x      67.55us                         2.13x
    64     524288   4096 |       152.50us     281.87us      1.85x      79.97us                         1.91x
   256        256    256 |         7.33us      17.12us      2.34x       2.43us                         3.01x
   256        512    256 |        13.76us      20.19us      1.47x       5.38us                         2.56x
   256        512    512 |         7.04us      19.78us      2.81x       2.66us                         2.65x
   256       1024    256 |        14.59us      31.94us      2.19x       5.98us                         2.44x
   256       1024    512 |        13.98us      29.98us      2.14x       5.73us                         2.44x
   256       1024   1024 |         7.68us      29.73us      3.87x       3.04us                         2.53x
   256       2048    256 |        15.26us      38.88us      2.55x       6.56us                         2.33x
   256       2048    512 |        15.07us      38.50us      2.55x       6.59us                         2.29x
   256       2048   1024 |        14.24us      37.25us      2.62x       6.69us                         2.13x
   256       2048   2048 |         8.96us      37.31us      4.16x       4.06us                         2.20x
   256       4096    256 |        15.65us      69.15us      4.42x       7.49us                         2.09x
   256       4096    512 |        15.84us      69.10us      4.36x       8.03us                         1.97x
   256       4096   1024 |        16.06us      70.50us      4.39x       8.38us                         1.92x
   256       4096   2048 |        15.78us      69.94us      4.43x       8.00us                         1.97x
   256       4096   4096 |        12.10us      76.61us      6.33x       4.38us                         2.76x
   256      16384    256 |        20.86us      96.30us      4.62x      14.06us                         1.48x
   256      16384    512 |        21.82us      95.54us      4.38x      15.39us                         1.42x
   256      16384   1024 |        23.87us      96.32us      4.03x      17.23us                         1.39x
   256      16384   2048 |        26.46us      97.66us      3.69x      19.04us                         1.39x
   256      16384   4096 |        30.40us     105.86us      3.48x      19.62us                         1.55x
   256      65536    256 |        59.36us     170.62us      2.87x      32.99us                         1.80x
   256      65536    512 |        63.14us     173.47us      2.75x      39.30us                         1.61x
   256      65536   1024 |        66.46us     175.07us      2.63x      40.34us                         1.65x
   256      65536   2048 |        70.75us     177.92us      2.51x      45.21us                         1.56x
   256      65536   4096 |        75.46us     187.70us      2.49x      52.83us                         1.43x
   256     131072    256 |       131.44us     279.50us      2.13x      61.92us                         2.12x
   256     131072    512 |       133.66us     279.49us      2.09x      62.26us                         2.15x
   256     131072   1024 |       140.35us     284.53us      2.03x      74.62us                         1.88x
   256     131072   2048 |       146.78us     286.75us      1.95x      76.69us                         1.91x
   256     131072   4096 |       155.62us     297.04us      1.91x      87.66us                         1.78x
   256     262144    256 |       235.68us     569.44us      2.42x     118.22us                         1.99x
   256     262144    512 |       240.83us     569.76us      2.37x     126.50us                         1.90x
   256     262144   1024 |       244.64us     568.51us      2.32x     128.93us                         1.90x
   256     262144   2048 |       256.86us     575.01us      2.24x     154.21us                         1.67x
   256     262144   4096 |       269.60us     583.66us      2.16x     165.41us                         1.63x
   256     524288    256 |       397.38us    1000.49us      2.52x     224.30us                         1.77x
   256     524288    512 |       400.06us     997.60us      2.49x     225.94us                         1.77x
   256     524288   1024 |       408.02us    1003.81us      2.46x     242.50us                         1.68x
   256     524288   2048 |       415.90us    1006.06us      2.42x     249.09us                         1.67x
   256     524288   4096 |       440.00us    1018.03us      2.31x     309.79us                         1.42x
  2048     131072   1024 |       932.05us    1937.69us      2.08x     462.10us                         2.02x
  4096     200000   1024 |      2442.41us    5674.84us      2.32x    1224.62us                         1.99x

@Aalanli
Copy link
Copy Markdown
Contributor Author

Aalanli commented Apr 8, 2026

And here's one for bf16:

> python benchmarks/bench_topk.py --op=top_k --input-pattern=random --dtype=bf16
====================================================================================================
top_k: Basic radix-based top-k selection (dtype=BF16, deterministic=False, pattern=random)
NOTE: default top-k sweep includes two extra large-batch/long-vocab stress cases beyond the original grid
====================================================================================================
 batch    seq_len      k |   FlashInfer   torch.topk    Speedup     Clusters  Speedup Clusters vs. Default
-------------------------------------------------------------------------------------------------------------------
     1        256    256 |         4.35us      14.94us      3.43x       1.89us                         2.31x
     1        512    256 |         7.71us      16.32us      2.12x       4.38us                         1.76x
     1        512    512 |         5.28us      15.94us      3.02x       1.95us                         2.70x
     1       1024    256 |         8.45us      20.16us      2.39x       4.61us                         1.83x
     1       1024    512 |         6.88us      17.54us      2.55x       4.32us                         1.59x
     1       1024   1024 |         5.10us      19.97us      3.91x       2.08us                         2.45x
     1       2048    256 |         8.02us      22.85us      2.85x       5.02us                         1.60x
     1       2048    512 |         8.99us      23.74us      2.64x       4.80us                         1.87x
     1       2048   1024 |         8.91us      22.72us      2.55x       4.99us                         1.79x
     1       2048   2048 |         5.86us      21.50us      3.67x       2.69us                         2.18x
     1       4096    256 |         8.45us      28.22us      3.34x       5.66us                         1.49x
     1       4096    512 |         9.49us      29.92us      3.15x       5.73us                         1.66x
     1       4096   1024 |         9.73us      31.07us      3.19x       5.89us                         1.65x
     1       4096   2048 |         9.63us      29.87us      3.10x       5.31us                         1.81x
     1       4096   4096 |         6.91us      28.54us      4.13x       2.56us                         2.70x
     1      16384    256 |        10.40us      64.42us      6.19x       7.90us                         1.32x
     1      16384    512 |        13.70us      64.48us      4.71x       8.58us                         1.60x
     1      16384   1024 |        14.27us      67.01us      4.70x       8.70us                         1.64x
     1      16384   2048 |        14.88us      71.38us      4.80x       8.83us                         1.68x
     1      16384   4096 |        16.66us      65.66us      3.94x       8.98us                         1.86x
     1      65536    256 |        30.37us      61.60us      2.03x       9.28us                         3.27x
     1      65536    512 |        32.16us      58.51us      1.82x       9.44us                         3.41x
     1      65536   1024 |        34.59us      62.83us      1.82x       9.28us                         3.73x
     1      65536   2048 |        37.47us      61.74us      1.65x      10.67us                         3.51x
     1      65536   4096 |        39.55us      62.21us      1.57x      10.85us                         3.65x
     1     131072    256 |        36.64us      63.94us      1.74x      11.92us                         3.07x
     1     131072    512 |        37.58us      65.15us      1.73x      11.94us                         3.15x
     1     131072   1024 |        39.10us      66.96us      1.71x      11.97us                         3.27x
     1     131072   2048 |        41.57us      64.62us      1.55x      11.97us                         3.47x
     1     131072   4096 |        44.61us      64.61us      1.45x      14.78us                         3.02x
     1     262144    256 |        44.38us      71.65us      1.61x      16.83us                         2.64x
     1     262144    512 |        44.99us      73.33us      1.63x      16.80us                         2.68x
     1     262144   1024 |        46.42us      75.10us      1.62x      16.93us                         2.74x
     1     262144   2048 |        47.97us      74.72us      1.56x      17.09us                         2.81x
     1     262144   4096 |        51.36us      73.38us      1.43x      17.17us                         2.99x
     1     524288    256 |        49.49us      89.52us      1.81x      26.34us                         1.88x
     1     524288    512 |        51.01us      88.16us      1.73x      26.53us                         1.92x
     1     524288   1024 |        52.03us      92.19us      1.77x      26.90us                         1.93x
     1     524288   2048 |        53.06us      94.56us      1.78x      26.98us                         1.97x
     1     524288   4096 |        54.14us      90.18us      1.67x      27.30us                         1.98x
    16        256    256 |         6.11us      16.26us      2.66x       1.98us                         3.08x
    16        512    256 |         9.50us      17.39us      1.83x       4.64us                         2.05x
    16        512    512 |         5.49us      17.38us      3.17x       2.11us                         2.60x
    16       1024    256 |         9.23us      20.74us      2.25x       4.93us                         1.87x
    16       1024    512 |         8.96us      20.86us      2.33x       4.86us                         1.84x
    16       1024   1024 |         6.37us      20.99us      3.30x       2.27us                         2.80x
    16       2048    256 |         9.54us      24.19us      2.54x       5.25us                         1.82x
    16       2048    512 |         9.12us      23.78us      2.61x       5.22us                         1.75x
    16       2048   1024 |         9.54us      25.22us      2.64x       5.17us                         1.85x
    16       2048   2048 |         7.10us      24.45us      3.44x       2.85us                         2.49x
    16       4096    256 |        10.27us      31.10us      3.03x       5.81us                         1.77x
    16       4096    512 |         9.62us      31.01us      3.22x       5.92us                         1.62x
    16       4096   1024 |        10.18us      31.42us      3.09x       6.08us                         1.67x
    16       4096   2048 |        10.27us      30.85us      3.00x       5.89us                         1.74x
    16       4096   4096 |         7.73us      34.43us      4.46x       2.67us                         2.89x
    16      16384    256 |        12.19us      71.81us      5.89x       9.60us                         1.27x
    16      16384    512 |        14.59us      70.56us      4.84x      11.02us                         1.32x
    16      16384   1024 |        15.07us      72.46us      4.81x      11.28us                         1.34x
    16      16384   2048 |        15.39us      73.02us      4.74x      11.42us                         1.35x
    16      16384   4096 |        17.39us      74.82us      4.30x      11.65us                         1.49x
    16      65536    256 |        32.35us      68.83us      2.13x      12.32us                         2.63x
    16      65536    512 |        33.15us      72.05us      2.17x      12.35us                         2.68x
    16      65536   1024 |        35.20us      70.98us      2.02x      12.45us                         2.83x
    16      65536   2048 |        38.02us      70.50us      1.85x      14.69us                         2.59x
    16      65536   4096 |        40.51us      71.84us      1.77x      15.30us                         2.65x
    16     131072    256 |        38.66us      79.44us      2.06x      16.22us                         2.38x
    16     131072    512 |        39.30us      80.77us      2.06x      16.45us                         2.39x
    16     131072   1024 |        40.70us      77.36us      1.90x      16.38us                         2.48x
    16     131072   2048 |        42.42us      80.85us      1.91x      16.66us                         2.55x
    16     131072   4096 |        45.01us      81.63us      1.81x      21.79us                         2.07x
    16     262144    256 |        46.18us      94.70us      2.05x      24.42us                         1.89x
    16     262144    512 |        46.46us      97.57us      2.10x      24.42us                         1.90x
    16     262144   1024 |        47.57us      95.36us      2.00x      24.42us                         1.95x
    16     262144   2048 |        48.90us      96.13us      1.97x      24.70us                         1.98x
    16     262144   4096 |        51.81us      98.26us      1.90x      25.22us                         2.05x
    16     524288    256 |        52.02us     127.09us      2.44x      39.90us                         1.30x
    16     524288    512 |        52.19us     128.83us      2.47x      39.87us                         1.31x
    16     524288   1024 |        53.06us     125.06us      2.36x      40.03us                         1.33x
    16     524288   2048 |        54.02us     129.90us      2.40x      40.38us                         1.34x
    16     524288   4096 |        55.71us     129.74us      2.33x      41.33us                         1.35x
    64        256    256 |         6.08us      16.19us      2.66x       2.14us                         2.84x
    64        512    256 |         9.57us      17.92us      1.87x       4.96us                         1.93x
    64        512    512 |         5.31us      18.02us      3.39x       2.27us                         2.34x
    64       1024    256 |         9.86us      21.50us      2.18x       5.06us                         1.95x
    64       1024    512 |         9.31us      21.41us      2.30x       4.99us                         1.87x
    64       1024   1024 |         6.62us      21.54us      3.25x       2.42us                         2.74x
    64       2048    256 |         9.73us      24.38us      2.51x       5.28us                         1.84x
    64       2048    512 |         9.02us      24.42us      2.71x       5.34us                         1.69x
    64       2048   1024 |        10.18us      25.70us      2.53x       5.34us                         1.90x
    64       2048   2048 |         7.33us      25.60us      3.49x       3.14us                         2.34x
    64       4096    256 |        10.58us      31.30us      2.96x       6.03us                         1.75x
    64       4096    512 |         9.60us      31.14us      3.24x       6.08us                         1.58x
    64       4096   1024 |        10.85us      32.13us      2.96x       6.24us                         1.74x
    64       4096   2048 |        10.75us      32.77us      3.05x       5.89us                         1.83x
    64       4096   4096 |         8.51us      34.11us      4.01x       2.98us                         2.86x
    64      16384    256 |        12.26us      68.70us      5.61x      10.30us                         1.19x
    64      16384    512 |        14.43us      67.70us      4.69x      11.30us                         1.28x
    64      16384   1024 |        15.78us      68.69us      4.35x      11.46us                         1.38x
    64      16384   2048 |        16.24us      67.87us      4.18x      11.55us                         1.41x
    64      16384   4096 |        17.54us      72.93us      4.16x      12.10us                         1.45x
    64      65536    256 |        33.31us      94.59us      2.84x      17.06us                         1.95x
    64      65536    512 |        34.11us      94.62us      2.77x      17.18us                         1.99x
    64      65536   1024 |        35.90us      95.30us      2.65x      17.54us                         2.05x
    64      65536   2048 |        38.98us      96.62us      2.48x      22.85us                         1.71x
    64      65536   4096 |        40.94us     100.45us      2.45x      24.61us                         1.66x
    64     131072    256 |        40.13us     123.12us      3.07x      25.12us                         1.60x
    64     131072    512 |        40.67us     124.96us      3.07x      25.22us                         1.61x
    64     131072   1024 |        41.57us     125.60us      3.02x      25.68us                         1.62x
    64     131072   2048 |        43.87us     126.67us      2.89x      26.16us                         1.68x
    64     131072   4096 |        46.67us     133.01us      2.85x      39.74us                         1.17x
    64     262144    256 |        87.01us     171.55us      1.97x      42.06us                         2.07x
    64     262144    512 |        87.30us     170.58us      1.95x      42.18us                         2.07x
    64     262144   1024 |        89.49us     173.74us      1.94x      42.66us                         2.10x
    64     262144   2048 |        91.87us     174.05us      1.89x      43.42us                         2.12x
    64     262144   4096 |        96.69us     178.51us      1.85x      44.90us                         2.15x
    64     524288    256 |       144.83us     279.15us      1.93x      77.42us                         1.87x
    64     524288    512 |       145.50us     279.78us      1.92x      77.62us                         1.87x
    64     524288   1024 |       146.70us     282.51us      1.93x      78.30us                         1.87x
    64     524288   2048 |       149.79us     284.13us      1.90x      78.64us                         1.90x
    64     524288   4096 |       154.78us     285.69us      1.85x      80.29us                         1.93x
   256        256    256 |         7.30us      17.57us      2.41x       2.43us                         3.00x
   256        512    256 |        14.56us      20.99us      1.44x       5.70us                         2.56x
   256        512    512 |         7.04us      20.32us      2.89x       2.66us                         2.65x
   256       1024    256 |        14.56us      31.81us      2.18x       6.02us                         2.42x
   256       1024    512 |        14.85us      31.39us      2.11x       5.92us                         2.51x
   256       1024   1024 |         7.71us      31.20us      4.05x       3.17us                         2.43x
   256       2048    256 |        15.20us      38.14us      2.51x       6.66us                         2.28x
   256       2048    512 |        15.04us      37.97us      2.52x       6.81us                         2.21x
   256       2048   1024 |        14.94us      38.94us      2.61x       6.93us                         2.16x
   256       2048   2048 |         8.99us      39.18us      4.36x       4.06us                         2.21x
   256       4096    256 |        16.48us      69.28us      4.20x       8.26us                         2.00x
   256       4096    512 |        16.22us      68.88us      4.25x       8.58us                         1.89x
   256       4096   1024 |        16.58us      70.58us      4.26x       8.96us                         1.85x
   256       4096   2048 |        16.67us      69.76us      4.18x       8.48us                         1.97x
   256       4096   4096 |        12.13us      75.87us      6.26x       4.38us                         2.77x
   256      16384    256 |        20.75us      96.25us      4.64x      14.11us                         1.47x
   256      16384    512 |        25.95us      96.32us      3.71x      19.14us                         1.36x
   256      16384   1024 |        27.18us      96.94us      3.57x      20.22us                         1.34x
   256      16384   2048 |        27.98us      97.66us      3.49x      20.86us                         1.34x
   256      16384   4096 |        30.85us     106.30us      3.45x      21.57us                         1.43x
   256      65536    256 |        61.52us     172.78us      2.81x      38.82us                         1.58x
   256      65536    512 |        63.79us     174.24us      2.73x      39.44us                         1.62x
   256      65536   1024 |        67.02us     175.39us      2.62x      40.35us                         1.66x
   256      65536   2048 |        73.47us     182.03us      2.48x      69.26us                         1.06x
   256      65536   4096 |        77.86us     190.11us      2.44x      75.82us                         1.03x
   256     131072    256 |       136.22us     280.98us      2.06x      72.93us                         1.87x
   256     131072    512 |       138.54us     281.20us      2.03x      73.50us                         1.88x
   256     131072   1024 |       142.75us     283.73us      1.99x      74.94us                         1.90x
   256     131072   2048 |       150.34us     287.01us      1.91x      76.51us                         1.96x
   256     131072   4096 |       163.18us     304.90us      1.87x     160.05us                         1.02x
   256     262144    256 |       246.88us     568.75us      2.30x     146.50us                         1.69x
   256     262144    512 |       249.53us     569.66us      2.28x     148.58us                         1.68x
   256     262144   1024 |       253.86us     572.70us      2.26x     151.09us                         1.68x
   256     262144   2048 |       261.57us     575.71us      2.20x     154.61us                         1.69x
   256     262144   4096 |       275.90us     584.85us      2.12x     166.34us                         1.66x
   256     524288    256 |       417.21us     998.33us      2.39x     291.28us                         1.43x
   256     524288    512 |       419.65us     993.21us      2.37x     294.18us                         1.43x
   256     524288   1024 |       424.29us    1007.95us      2.38x     298.75us                         1.42x
   256     524288   2048 |       432.67us    1008.69us      2.33x     306.13us                         1.41x
   256     524288   4096 |       447.47us    1019.81us      2.28x     321.22us                         1.39x
  2048     131072   1024 |       948.75us    1933.69us      2.04x     463.31us                         2.05x
  4096     200000   1024 |      2537.64us    5690.76us      2.24x    1424.65us                         1.78x

@Aalanli
Copy link
Copy Markdown
Contributor Author

Aalanli commented Apr 8, 2026

Note that all benchmarks are done on B200.

@yzh119
Copy link
Copy Markdown
Collaborator

yzh119 commented Apr 8, 2026

Great work and I'll help review this one, using CTA cluster is definitely a good idea.

This is a non-deterministic algorithm, but does not drop indices and instead overflows to global memory.

Hi @Aalanli, can you clarify about non-deterministic here?

  • Does it mean the order of output elements is non-deterministic? More specifically, does it guarantee set(top_k_first_run) == set(top_k_second_run)? If so, I think non-determinism should be tolerable.
  • How does it deal with tie elements? Will different run select the same set of tie elements?

@bkryu bkryu added the run-ci label Apr 8, 2026
@Aalanli
Copy link
Copy Markdown
Contributor Author

Aalanli commented Apr 9, 2026

Hi @yzh119, the order of indices is not guaranteed to stay consistent across runs. If there are tie elements the set of indices is not guaranteed to be the same. In practice the tests assert greater than 99% overlap for f32 and f16 dtype; the set of values should be the same regardless.

The algorithm biases earlier elements if there are any tie elements, but due to atomics it's not guaranteed. I think this algorithm has the same properties as the filtered topk algorithm flashinfer already has, for the non-deterministic version.

The changes I made to the python API selects this algorithm by default if deterministic=False and sm_100+ is available. Otherwise if the user selects deterministic=True then we get what flashinfer already has.

};

template <typename T>
struct OrderedBits;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is duplicated with RadixTopKTraits in topk.cuh. Should we create a topk_common.cuh file to reuse the struct OrderedBits in both topk.cuh and fast_topk_clusters_exact.cuh?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that sounds good.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @jiangyinzuo , I moved the common stuff into topk_common.cuh.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@include/flashinfer/fast_topk_clusters_exact.cuh`:
- Around line 561-583: The helper launch_topk_cluster_kernel currently calls
cudaFuncSetAttribute and cudaLaunchKernelExC without checking their return
values; update launch_topk_cluster_kernel to capture and handle cudaError_t
results from cudaFuncSetAttribute (both calls) and from cudaLaunchKernelExC,
e.g., check the returned error, and on failure either assert in debug builds or
propagate/return the error (or log it) so failures on non‑SM90 hardware are
visible; reference the cudaFuncSetAttribute calls near the top of
launch_topk_cluster_kernel and the final cudaLaunchKernelExC call when adding
these checks.

In `@include/flashinfer/topk_common.cuh`:
- Around line 4-8: The header is missing CUDA headers required for types and
intrinsics used: add the appropriate CUDA includes so symbols like
cuda::std::numeric_limits<float>::infinity(), half and
__half_as_ushort/__ushort_as_half, and
nv_bfloat16/__bfloat16_as_ushort/__ushort_as_bfloat16 are defined; specifically
include <cuda/std/limits> for cuda::std::numeric_limits, <cuda_fp16.h> for half
and half intrinsics, and <cuda_bf16.h> for nv_bfloat16 and its intrinsics
(update include list near the top of topk_common.cuh where other std headers are
listed).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 6a813e72-279a-450e-bc99-0da55c5ee316

📥 Commits

Reviewing files that changed from the base of the PR and between 42da8a7 and a2e4e17.

📒 Files selected for processing (4)
  • csrc/flashinfer_fast_topk_clusters_binding.cu
  • include/flashinfer/fast_topk_clusters_exact.cuh
  • include/flashinfer/topk.cuh
  • include/flashinfer/topk_common.cuh
✅ Files skipped from review due to trivial changes (1)
  • csrc/flashinfer_fast_topk_clusters_binding.cu

Comment thread include/flashinfer/fast_topk_clusters_exact.cuh
Comment thread include/flashinfer/topk_common.cuh
@kahyunnam
Copy link
Copy Markdown
Member

/bot run

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

GitLab MR !546 has been created, and the CI pipeline #48511515 is currently running. I'll report back once the pipeline job completes.

@kahyunnam
Copy link
Copy Markdown
Member

kahyunnam commented Apr 14, 2026

@Aalanli the pre-commit format check is failing: https://github.com/flashinfer-ai/flashinfer/actions/runs/24349849138/job/71303983434?pr=3009

Could you please rerun pre-commit and push?

Also, could you either address or resolve the AI code review comments above?

@Aalanli
Copy link
Copy Markdown
Contributor Author

Aalanli commented Apr 14, 2026

Hi @kahyunnam thanks for taking a look. I addressed AI comments, and some build issues as well as pre-commit.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

♻️ Duplicate comments (3)
include/flashinfer/fast_topk_clusters_exact.cuh (2)

414-424: ⚠️ Potential issue | 🟡 Minor

Initialize padded output_values slots too.

The valid entries are fixed now, but when TopK > seq_len the else branch still leaves output_values[ind_offset + i] untouched while output_indices is set to -1. That returns garbage in the padded tail of the row.

💡 Minimal fix
       } else {
         output_indices[ind_offset + i] = static_cast<IdxT>(-1);
+        if (output_values != nullptr) {
+          output_values[ind_offset + i] = RadixTopKTraits<T>::NegInf();
+        }
       }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/fast_topk_clusters_exact.cuh` around lines 414 - 424, The
padded output_values slots aren't initialized when TopK > seq_len: in the branch
inside fast_topk_clusters_exact.cuh where you set output_indices[ind_offset + i]
= -1 for i >= seq_len, also set output_values[ind_offset + i] to a defined
sentinel (e.g., zero or -INF consistent with your API) so the padded tail
doesn't return garbage; update the same loop that uses output_values and
logits/logit_offset to initialize output_values in that else branch.

552-574: ⚠️ Potential issue | 🟡 Minor

Surface CUDA API failures from the launch helper.

cudaFuncSetAttribute and cudaLaunchKernelExC both return cudaError_t, but this helper ignores them. If the cluster launch or shared-memory opt-in fails, the later cudaGetLastError() in the binding won't pinpoint the failing call and can miss attribute failures entirely.

🔍 Minimal debug-visible handling
 inline void launch_topk_cluster_kernel(void* kernel, void** args, int grid_dim, int smem_bytes,
                                        int num_clusters, bool pdl_enabled, cudaStream_t stream) {
-  cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, MAX_SMEM_CARVEOUT);
-  cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+  auto err =
+      cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, MAX_SMEM_CARVEOUT);
+  assert(err == cudaSuccess);
+  err = cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+  assert(err == cudaSuccess);
 
   cudaLaunchConfig_t config;
   config.numAttrs = 0;
@@
   config.dynamicSmemBytes = smem_bytes;
   config.gridDim = grid_dim;
   config.stream = stream;
   config.attrs = attribute;
-  cudaLaunchKernelExC(&config, kernel, args);
+  err = cudaLaunchKernelExC(&config, kernel, args);
+  assert(err == cudaSuccess);
 }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/fast_topk_clusters_exact.cuh` around lines 552 - 574, The
helper launch_topk_cluster_kernel currently ignores return values from
cudaFuncSetAttribute and cudaLaunchKernelExC so failures are hidden; fix it by
making launch_topk_cluster_kernel return cudaError_t (instead of void), check
the cudaError_t result after each cudaFuncSetAttribute and after
cudaLaunchKernelExC, and immediately return the error on failure (or propagate
it) so callers can handle/log it; update callers to handle the returned
cudaError_t and propagate or log it accordingly. Use the existing symbols
cudaFuncSetAttribute, cudaLaunchKernelExC, and launch_topk_cluster_kernel to
locate and change the code.
include/flashinfer/topk_common.cuh (1)

4-7: ⚠️ Potential issue | 🟠 Major

Make topk_common.cuh self-contained.

RadixTopKTraits now names cuda::std::numeric_limits, half, and nv_bfloat16, but this header still only includes libc headers. That leaves it dependent on include order, and include/flashinfer/fast_topk_clusters_exact.cuh currently only brings in <cuda_fp16.h> before including it, so cuda::std and nv_bfloat16 are still unresolved here.

🔧 Minimal fix
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda/std/limits>
+
 `#include` <cstdint>
 `#include` <cstdlib>
 `#include` <numeric>
 `#include` <type_traits>
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/topk_common.cuh` around lines 4 - 7, RadixTopKTraits in
topk_common.cuh relies on cuda::std::numeric_limits and the types half and
nv_bfloat16 but the header only includes libc headers; make topk_common.cuh
self-contained by adding the necessary includes (e.g., <limits> to provide
numeric_limits, <cuda_fp16.h> to provide half, and <cuda_bf16.h> to provide
nv_bfloat16) so RadixTopKTraits compiles independently of include order.
🧹 Nitpick comments (1)
include/flashinfer/fast_topk_clusters_exact.cuh (1)

209-215: Document why this branch spills instead of taking the cheaper alternatives.

The current comment explains the mechanism, but not why the exact path chooses global spill over dropping overflowed candidates or doing another pass. One sentence on that trade-off would make future tuning much safer. As per coding guidelines, "For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/fast_topk_clusters_exact.cuh` around lines 209 - 215, Add
a one-sentence justification above the spill branch explaining why the code
chooses to spill to the per-CTA global overflow cache (using
s_cached_overflow_count, overflow_stride, get_cached_overflow and writing
PackedCachedData) instead of cheaper alternatives like dropping overflowed
candidates or performing another pass; mention the trade-off: preserving
candidate correctness / avoiding additional kernel passes at the cost of a rare
global spill and minimal memory overhead, and note that alternatives were
considered but rejected due to increased error rate or extra
synchronization/latency.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@include/flashinfer/fast_topk_clusters_exact.cuh`:
- Around line 414-424: The padded output_values slots aren't initialized when
TopK > seq_len: in the branch inside fast_topk_clusters_exact.cuh where you set
output_indices[ind_offset + i] = -1 for i >= seq_len, also set
output_values[ind_offset + i] to a defined sentinel (e.g., zero or -INF
consistent with your API) so the padded tail doesn't return garbage; update the
same loop that uses output_values and logits/logit_offset to initialize
output_values in that else branch.
- Around line 552-574: The helper launch_topk_cluster_kernel currently ignores
return values from cudaFuncSetAttribute and cudaLaunchKernelExC so failures are
hidden; fix it by making launch_topk_cluster_kernel return cudaError_t (instead
of void), check the cudaError_t result after each cudaFuncSetAttribute and after
cudaLaunchKernelExC, and immediately return the error on failure (or propagate
it) so callers can handle/log it; update callers to handle the returned
cudaError_t and propagate or log it accordingly. Use the existing symbols
cudaFuncSetAttribute, cudaLaunchKernelExC, and launch_topk_cluster_kernel to
locate and change the code.

In `@include/flashinfer/topk_common.cuh`:
- Around line 4-7: RadixTopKTraits in topk_common.cuh relies on
cuda::std::numeric_limits and the types half and nv_bfloat16 but the header only
includes libc headers; make topk_common.cuh self-contained by adding the
necessary includes (e.g., <limits> to provide numeric_limits, <cuda_fp16.h> to
provide half, and <cuda_bf16.h> to provide nv_bfloat16) so RadixTopKTraits
compiles independently of include order.

---

Nitpick comments:
In `@include/flashinfer/fast_topk_clusters_exact.cuh`:
- Around line 209-215: Add a one-sentence justification above the spill branch
explaining why the code chooses to spill to the per-CTA global overflow cache
(using s_cached_overflow_count, overflow_stride, get_cached_overflow and writing
PackedCachedData) instead of cheaper alternatives like dropping overflowed
candidates or performing another pass; mention the trade-off: preserving
candidate correctness / avoiding additional kernel passes at the cost of a rare
global spill and minimal memory overhead, and note that alternatives were
considered but rejected due to increased error rate or extra
synchronization/latency.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a6b88abe-9adf-4b60-9a41-e20f14433ba8

📥 Commits

Reviewing files that changed from the base of the PR and between a2e4e17 and 1badf7c.

📒 Files selected for processing (2)
  • include/flashinfer/fast_topk_clusters_exact.cuh
  • include/flashinfer/topk_common.cuh

Copy link
Copy Markdown
Member

@kahyunnam kahyunnam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @Aalanli . Will wait on merging until premerge passes and if @yzh119 wants to take another look

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

@Aalanli is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@kahyunnam
Copy link
Copy Markdown
Member

/bot run

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

GitLab MR !546 has been updated with latest changes, and the CI pipeline #48716096 is currently running. I'll report back once the pipeline job completes.

@Aalanli
Copy link
Copy Markdown
Contributor Author

Aalanli commented Apr 16, 2026

Hi @kahyunnam , I think the build failure is not due to the changes in this PR, do you know what's the issue?

@kahyunnam kahyunnam enabled auto-merge (squash) April 16, 2026 23:50
@kahyunnam
Copy link
Copy Markdown
Member

@Aalanli I think the build is just flakey, we can go ahead and merge. The "Test Results Summary" step is still ongoing but I've enabled automerge

auto-merge was automatically disabled April 19, 2026 21:15

Head branch was pushed to by a user without write access

@Aalanli
Copy link
Copy Markdown
Contributor Author

Aalanli commented Apr 20, 2026

Hi @kahyunnam, I found and fixed an edge case that manifested only some of the time in bfloat16 case (the last bin could be distributed in such a way that the threshold bin contains some topk values), the performance is still competitive. The other tests that previously failed were due to tests/utils/test_triton_cascade.py so I rebased on main.

Eg:

====================================================================================================
top_k: Basic radix-based top-k selection (dtype=FP32, deterministic=False, pattern=random)
NOTE: default top-k sweep includes two extra large-batch/long-vocab stress cases beyond the original grid
====================================================================================================
 batch    seq_len      k |   FlashInfer   torch.topk    Speedup     Clusters  Speedup Clusters vs. Default
-------------------------------------------------------------------------------------------------------------------
     1        256    256 |         4.32us      19.94us      4.62x       1.89us                         2.29x
     1        512    256 |         5.54us      21.50us      3.88x       5.86us                         0.95x
     1        512    512 |         5.25us      21.70us      4.13x       1.98us                         2.65x
     1       1024    256 |         8.58us      25.60us      2.99x       5.95us                         1.44x
     1       1024    512 |         6.53us      23.71us      3.63x       5.98us                         1.09x
     1       1024   1024 |         5.09us      24.85us      4.88x       2.11us                         2.41x
     1       2048    256 |         8.48us      29.66us      3.50x       6.43us                         1.32x
     1       2048    512 |        10.53us      30.91us      2.94x       6.62us                         1.59x
     1       2048   1024 |         6.46us      26.66us      4.12x       6.40us                         1.01x
     1       2048   2048 |         5.60us      28.54us      5.10x       2.53us                         2.21x
     1       4096    256 |         8.67us      35.33us      4.07x       6.56us                         1.32x
     1       4096    512 |        10.94us      36.80us      3.36x       6.62us                         1.65x
     1       4096   1024 |        11.04us      38.80us      3.51x       6.94us                         1.59x
     1       4096   2048 |         9.92us      35.04us      3.53x       6.91us                         1.44x
     1       4096   4096 |         6.88us      38.67us      5.62x       2.62us                         2.62x
     1      16384    256 |        12.13us      81.02us      6.68x      11.46us                         1.06x
     1      16384    512 |        14.50us      88.10us      6.08x      12.06us                         1.20x
     1      16384   1024 |        15.52us      86.14us      5.55x      12.06us                         1.29x
     1      16384   2048 |        16.96us      88.06us      5.19x      12.19us                         1.39x
     1      16384   4096 |        21.09us      90.88us      4.31x      12.48us                         1.69x
     1      65536    256 |        34.98us      90.29us      2.58x      12.42us                         2.82x
     1      65536    512 |        35.89us      88.40us      2.46x      12.61us                         2.85x
     1      65536   1024 |        37.63us      87.81us      2.33x      12.51us                         3.01x
     1      65536   2048 |        39.20us      88.32us      2.25x      13.54us                         2.90x
     1      65536   4096 |        40.35us      90.19us      2.24x      13.98us                         2.89x
     1     131072    256 |        41.02us      96.42us      2.35x      14.88us                         2.76x
     1     131072    512 |        41.76us      95.52us      2.29x      15.04us                         2.78x
     1     131072   1024 |        43.17us      97.60us      2.26x      15.01us                         2.88x

@kahyunnam kahyunnam enabled auto-merge (squash) April 20, 2026 15:52
@kahyunnam kahyunnam merged commit 44a2672 into flashinfer-ai:main Apr 20, 2026
70 of 85 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants